Revert D30561459: Fix bytes_written and bytes_read
author     Alban Desmaison <albandes@fb.com>
           Mon, 30 Aug 2021 21:56:35 +0000 (14:56 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Mon, 30 Aug 2021 21:59:54 +0000 (14:59 -0700)
Test Plan: revert-hammer

Differential Revision: D30561459 (https://github.com/pytorch/pytorch/commit/e98173ff3423247c597e21c923c8f47470ef07ab)

Original commit changeset: 976fa5167097

fbshipit-source-id: 43f4c234ca400820fe6db5b4f37a25e14dc4b0dd

caffe2/core/operator_schema.h
caffe2/operators/batch_matmul_op.cc
caffe2/operators/concat_split_op.cc
caffe2/operators/conv_pool_op_base.h
caffe2/operators/distance_op.cc
caffe2/operators/fc_inference.cc
caffe2/operators/one_hot_ops.cc
caffe2/operators/utility_ops.cc
caffe2/python/operator_test/concat_op_cost_test.py
caffe2/python/workspace_test.py
caffe2/sgd/adagrad_op.cc

diff --git a/caffe2/core/operator_schema.h b/caffe2/core/operator_schema.h
index 0d048eb..64f5ef3 100644
@@ -6,13 +6,12 @@
 #include <initializer_list>
 #include <ostream>
 #include <set>
-#include <unordered_map>
 #include <vector>
+#include <unordered_map>
 
 #include "c10/util/Registry.h"
 #include "caffe2/core/common.h"
 #include "caffe2/core/logging.h"
-#include "caffe2/core/types.h"
 #include "caffe2/proto/caffe2_pb.h"
 #include "caffe2/utils/filler.h"
 #include "caffe2/utils/proto_utils.h"
@@ -274,8 +273,8 @@ class TORCH_API OpSchema {
   OpSchema&
   Arg(const char* name, const char* description, bool required = false);
 
-#define DECLARE_STANDARD_ARG(name, str) \
-  static const char* Arg_##name;        \
+#define DECLARE_STANDARD_ARG(name, str)     \
+  static const char* Arg_##name; \
   OpSchema& Arg##name(const char* description);
 
   DECLARE_STANDARD_ARG(IsTest, is_test)
@@ -340,9 +339,7 @@ class TORCH_API OpSchema {
     return inplace_enforced_(x, y);
   }
 
-  TORCH_API friend std::ostream& operator<<(
-      std::ostream& out,
-      const OpSchema& schema);
+  TORCH_API friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema);
 
   const std::vector<Argument>& args() const {
     return args_;
@@ -565,10 +562,8 @@ OpSchema::Cost PointwiseCostInference(
   }
 
   c.flops = nElemX * OpsPerPoint;
-  auto const& X_element_size_byte =
-      DataTypeToTypeMeta(X.data_type()).itemsize();
-  c.bytes_read = nElemRead * X_element_size_byte;
-  c.bytes_written = nElemX * X_element_size_byte;
+  c.bytes_read = nElemRead * sizeof(X.data_type());
+  c.bytes_written = nElemX * sizeof(X.data_type());
   return c;
 }
 
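For context on the formulas above (an illustrative sketch, not part of this commit): TensorShape::data_type() returns a TensorProto::DataType enum value, so the restored sizeof(X.data_type()) measures the enum type itself (typically 4 bytes) regardless of the tensor's element type, while the reverted DataTypeToTypeMeta(...).itemsize() looks up the real per-element size. A minimal standalone sketch with stand-in types:

    #include <cstddef>
    #include <cstdio>

    enum DataType { FLOAT = 1, INT64 = 10, FLOAT16 = 12 };  // stand-in for TensorProto::DataType

    // stand-in for DataTypeToTypeMeta(...).itemsize()
    static std::size_t itemsize(DataType t) {
      switch (t) {
        case FLOAT:   return 4;
        case INT64:   return 8;
        case FLOAT16: return 2;
        default:      return 0;
      }
    }

    int main() {
      DataType t = FLOAT16;
      // sizeof(t) is the enum's storage size (4 on typical ABIs), not the
      // 2-byte fp16 element size, so sizeof-based byte counts drift for
      // every element type that is not 4 bytes wide.
      std::printf("sizeof(enum)=%zu itemsize=%zu\n", sizeof(t), itemsize(t));
    }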
diff --git a/caffe2/operators/batch_matmul_op.cc b/caffe2/operators/batch_matmul_op.cc
index 205acf7..32799ce 100644
@@ -1,7 +1,6 @@
 #include "caffe2/operators/batch_matmul_op.h"
 
 #include "caffe2/core/operator_schema.h"
-#include "caffe2/core/types.h"
 
 namespace caffe2 {
 
@@ -117,13 +116,9 @@ OpSchema::Cost CostInferenceForBatchMatMul(
     K = in[0].dims(ndims_A - 1);
   }
 
-  auto const& A_element_size_byte =
-      DataTypeToTypeMeta(A.data_type()).itemsize();
-  auto const& Y_element_size_byte =
-      DataTypeToTypeMeta(Y.data_type()).itemsize();
   c.flops = 2 * nElemY * K;
-  c.bytes_read = (nElemA + nElemB) * A_element_size_byte;
-  c.bytes_written = nElemY * Y_element_size_byte;
+  c.bytes_read = (nElemA + nElemB) * sizeof(A.data_type());
+  c.bytes_written = nElemY * sizeof(Y.data_type());
   c.params_bytes = 0;
   return c;
 }
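A hedged worked example of the cost arithmetic above (all shapes invented for illustration): a batch of 4 products of A (8x16) with B (16x32), giving Y (8x32) and K = 16, with 4-byte float elements assumed.

    #include <cstdio>

    int main() {
      unsigned long long nElemA = 4 * 8 * 16, nElemB = 4 * 16 * 32, nElemY = 4 * 8 * 32;
      unsigned long long K = 16, elem = 4;        // assuming 4-byte float elements
      unsigned long long flops = 2 * nElemY * K;  // 32768: one mul + one add per (y, k) pair
      unsigned long long bytes_read = (nElemA + nElemB) * elem;  // 10240
      unsigned long long bytes_written = nElemY * elem;          // 4096
      std::printf("flops=%llu read=%llu written=%llu\n", flops, bytes_read, bytes_written);
    }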
@@ -185,76 +180,72 @@ class GetBatchMatMulGradient : public GradientMakerBase {
     auto no_trans_arg = vector<Argument>();
     auto trans_a_arg = vector<Argument>{MakeArgument<int>("trans_a", 1)};
     auto trans_b_arg = vector<Argument>{MakeArgument<int>("trans_b", 1)};
-    auto trans_both_arg = vector<Argument>{
-        MakeArgument<int>("trans_a", 1), MakeArgument<int>("trans_b", 1)};
+    auto trans_both_arg = vector<Argument>{MakeArgument<int>("trans_a", 1),
+                                           MakeArgument<int>("trans_b", 1)};
 
     if (trans_a) {
       if (trans_b) {
         // A'B':
         // dA = B'G', dB = G'A'
-        return vector<OperatorDef>{
-            CreateOperatorDef(
-                "BatchMatMul",
-                "",
-                vector<string>{I(1), GO(0)},
-                vector<string>{GI(0)},
-                trans_both_arg),
-            CreateOperatorDef(
-                "BatchMatMul",
-                "",
-                vector<string>{GO(0), I(0)},
-                vector<string>{GI(1)},
-                trans_both_arg)};
+        return vector<OperatorDef>{CreateOperatorDef(
+                                       "BatchMatMul",
+                                       "",
+                                       vector<string>{I(1), GO(0)},
+                                       vector<string>{GI(0)},
+                                       trans_both_arg),
+                                   CreateOperatorDef(
+                                       "BatchMatMul",
+                                       "",
+                                       vector<string>{GO(0), I(0)},
+                                       vector<string>{GI(1)},
+                                       trans_both_arg)};
       } else {
         // A'B:
         // dA = BG', dB = AG
-        return vector<OperatorDef>{
-            CreateOperatorDef(
-                "BatchMatMul",
-                "",
-                vector<string>{I(1), GO(0)},
-                vector<string>{GI(0)},
-                trans_b_arg),
-            CreateOperatorDef(
-                "BatchMatMul",
-                "",
-                vector<string>{I(0), GO(0)},
-                vector<string>{GI(1)},
-                no_trans_arg)};
+        return vector<OperatorDef>{CreateOperatorDef(
+                                       "BatchMatMul",
+                                       "",
+                                       vector<string>{I(1), GO(0)},
+                                       vector<string>{GI(0)},
+                                       trans_b_arg),
+                                   CreateOperatorDef(
+                                       "BatchMatMul",
+                                       "",
+                                       vector<string>{I(0), GO(0)},
+                                       vector<string>{GI(1)},
+                                       no_trans_arg)};
       }
     } else {
       if (trans_b) {
         // AB':
         // dA = GB, dB = G'A
-        return vector<OperatorDef>{
-            CreateOperatorDef(
-                "BatchMatMul",
-                "",
-                vector<string>{GO(0), I(1)},
-                vector<string>{GI(0)},
-                no_trans_arg),
-            CreateOperatorDef(
-                "BatchMatMul",
-                "",
-                vector<string>{GO(0), I(0)},
-                vector<string>{GI(1)},
-                trans_a_arg)};
+        return vector<OperatorDef>{CreateOperatorDef(
+                                       "BatchMatMul",
+                                       "",
+                                       vector<string>{GO(0), I(1)},
+                                       vector<string>{GI(0)},
+                                       no_trans_arg),
+                                   CreateOperatorDef(
+                                       "BatchMatMul",
+                                       "",
+                                       vector<string>{GO(0), I(0)},
+                                       vector<string>{GI(1)},
+                                       trans_a_arg)};
       } else {
         // AB:
         // dA = GB', dB = A'G
-        return vector<OperatorDef>{
-            CreateOperatorDef(
-                "BatchMatMul",
-                "",
-                vector<string>{GO(0), I(1)},
-                vector<string>{GI(0)},
-                trans_b_arg),
-            CreateOperatorDef(
-                "BatchMatMul",
-                "",
-                vector<string>{I(0), GO(0)},
-                vector<string>{GI(1)},
-                trans_a_arg)};
+        return vector<OperatorDef>{CreateOperatorDef(
+                                       "BatchMatMul",
+                                       "",
+                                       vector<string>{GO(0), I(1)},
+                                       vector<string>{GI(0)},
+                                       trans_b_arg),
+                                   CreateOperatorDef(
+                                       "BatchMatMul",
+                                       "",
+                                       vector<string>{I(0), GO(0)},
+                                       vector<string>{GI(1)},
+                                       trans_a_arg)};
       }
     }
   }
diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc
index 8aa9e28..8eceb5a 100644
@@ -101,12 +101,9 @@ OpSchema::Cost CostInferenceForSplit(
   CAFFE_ENFORCE_GT(in.size(), 0);
   struct OpSchema::Cost cost;
   cost.flops = 0;
-  auto const& input_0_element_size_byte =
-      DataTypeToTypeMeta(in[0].data_type()).itemsize();
-  auto const& input_1_element_size_byte =
-      (in.size() > 1) ? DataTypeToTypeMeta(in[1].data_type()).itemsize() : 0;
-  auto input_bytes_count = nElemFromDim(in[0]) * input_0_element_size_byte;
-  auto split_bytes_count = nElemFromDim(in[1]) * input_1_element_size_byte;
+  auto input_bytes_count = nElemFromDim(in[0]) * sizeof(in[0].data_type());
+  auto split_bytes_count =
+      (in.size() == 1) ? 0 : nElemFromDim(in[1]) * sizeof(in[1].data_type());
   // There can be two input blobs:
   // (1) actual tensor to be split
   // (2) lengths of outputs along split axis
@@ -332,13 +329,11 @@ OpSchema::Cost CostInferenceForConcat(
   }
   auto split_info_bytes_count = in.size() * sizeof(int);
 
-  auto const& input_0_element_size_byte =
-      DataTypeToTypeMeta(in[0].data_type()).itemsize();
   struct OpSchema::Cost cost;
   cost.flops = 0;
-  cost.bytes_read = nElemRead * input_0_element_size_byte;
+  cost.bytes_read = nElemRead * sizeof(in[0].data_type());
   cost.bytes_written =
-      size * input_0_element_size_byte + split_info_bytes_count;
+      size * sizeof(in[0].data_type()) + split_info_bytes_count;
   cost.params_bytes = 0;
   return cost;
 }
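A hedged cross-check of the concat formula against the int32 case in concat_op_cost_test.py further below (input_1 is 2x3, input_2 is 2x1, so the output is 2x4 plus a two-entry split_info blob):

    #include <cstdio>

    int main() {
      unsigned long long elem = 4;                   // int32 element size
      unsigned long long nElemRead = 2 * 3 + 2 * 1;  // 8 elements read across both inputs
      unsigned long long size = 2 * 4;               // 8 elements in the output blob
      unsigned long long split_info_bytes = 2 * sizeof(int);  // one int per input
      // Prints read=32 written=40, matching the nbytes sums the test asserts.
      std::printf("read=%llu written=%llu\n",
                  nElemRead * elem, size * elem + split_info_bytes);
    }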
diff --git a/caffe2/operators/conv_pool_op_base.h b/caffe2/operators/conv_pool_op_base.h
index b356ef9..25bd99a 100644
@@ -7,7 +7,6 @@
 #include "caffe2/core/context.h"
 #include "caffe2/core/logging.h"
 #include "caffe2/core/operator.h"
-#include "caffe2/core/types.h"
 #include "caffe2/proto/caffe2_legacy.pb.h"
 #include "caffe2/utils/math.h"
 
@@ -520,20 +519,14 @@ class ConvPoolOpBase : public Operator<Context> {
     uint64_t nElemW = nElemFromDim(W);
     uint64_t nElemBias = inputs.size() > 2 ? nElemFromDim(inputs[2]) : 0;
 
-    auto const& X_elemenet_size_byte =
-        DataTypeToTypeMeta(X.data_type()).itemsize();
-    auto const& Y_element_size_byte =
-        DataTypeToTypeMeta(Y.data_type()).itemsize();
-    auto const& W_element_size_byte =
-        DataTypeToTypeMeta(W.data_type()).itemsize();
-
     // grouping is NOT properly handled yet
     c.flops = N * Y_t * Y_h * Y_w * kernel_t * kernel_w * kernel_h *
         in_channels * out_channels * 2;
-    c.bytes_read = (nElemX + nElemW + nElemBias) * X_elemenet_size_byte;
-    c.bytes_written = N * out_channels * Y_t * Y_h * Y_w * Y_element_size_byte;
+    c.bytes_read = (nElemX + nElemW + nElemBias) * sizeof(X.data_type());
+    c.bytes_written =
+        N * out_channels * Y_t * Y_h * Y_w * sizeof(Y.data_type());
     c.params_bytes = out_channels * in_channels * kernel_t * kernel_h *
-        kernel_w * W_element_size_byte;
+        kernel_w * sizeof(W.data_type());
     return c;
   }
 
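A hedged worked example of the convolution flop count above (all values invented): N=1, a 3x3 spatial kernel with kernel_t=1, in_channels=3, out_channels=16, and a 1x32x32 output.

    #include <cstdio>

    int main() {
      unsigned long long N = 1, Y_t = 1, Y_h = 32, Y_w = 32;
      unsigned long long kernel_t = 1, kernel_h = 3, kernel_w = 3;
      unsigned long long in_channels = 3, out_channels = 16;
      // Same expression as the hunk above: two flops (mul + add) per kernel tap
      // per output element; grouping is ignored, as the comment in the hunk notes.
      unsigned long long flops = N * Y_t * Y_h * Y_w * kernel_t * kernel_w *
          kernel_h * in_channels * out_channels * 2;
      std::printf("flops=%llu\n", flops);  // 884736
    }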
diff --git a/caffe2/operators/distance_op.cc b/caffe2/operators/distance_op.cc
index 9ea8eea..1529534 100644
@@ -1,5 +1,4 @@
 #include "caffe2/operators/distance_op.h"
-#include "caffe2/core/types.h"
 #include "caffe2/utils/eigen_utils.h"
 #ifdef CAFFE2_USE_MKLDNN
 #include <caffe2/ideep/operators/operator_fallback_ideep.h>
@@ -8,7 +7,7 @@
 
 namespace caffe2 {
 
-template <>
+template<>
 bool SquaredL2DistanceOp<float, CPUContext>::RunOnDevice() {
   auto& X = Input(0);
   auto& Y = Input(1);
@@ -258,9 +257,7 @@ OpSchema::Cost CostInferenceForDotProduct(
   CAFFE_ENFORCE_EQ(out[0].dims().size(), 1);
 
   struct OpSchema::Cost c = PointwiseCostInference<2>(def, in);
-  auto const& out_0_element_size_byte =
-      DataTypeToTypeMeta(out[0].data_type()).itemsize();
-  c.bytes_written = out[0].dims(0) * out_0_element_size_byte;
+  c.bytes_written = out[0].dims(0) * sizeof(out[0].data_type());
   c.params_bytes = 0;
   return c;
 }
@@ -382,12 +379,10 @@ bool DotProductWithPaddingOp<float, CPUContext>::RunOnDevice() {
 }
 
 // L2
-REGISTER_CPU_OPERATOR(
-    SquaredL2Distance,
-    SquaredL2DistanceOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(
-    SquaredL2DistanceGradient,
-    SquaredL2DistanceGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(SquaredL2Distance,
+                      SquaredL2DistanceOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(SquaredL2DistanceGradient,
+                      SquaredL2DistanceGradientOp<float, CPUContext>);
 
 OPERATOR_SCHEMA(SquaredL2Distance)
     .NumInputs(2)
@@ -407,8 +402,7 @@ class GetSquaredL2DistanceGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
   vector<OperatorDef> GetGradientDefs() override {
     return SingleGradientDef(
-        "SquaredL2DistanceGradient",
-        "",
+        "SquaredL2DistanceGradient", "",
         vector<string>{I(0), I(1), GO(0)},
         vector<string>{GI(0), GI(1)});
   }
@@ -768,9 +762,9 @@ class GetDotProductWithPaddingGradient : public GradientMakerBase {
       replicate = GetArgument(Def(), "replicate").i();
     }
 
-    const auto dot_arg = vector<Argument>{
-        MakeArgument<float>("pad_value", pad_value),
-        MakeArgument<bool>("replicate", replicate)};
+    const auto dot_arg =
+        vector<Argument>{MakeArgument<float>("pad_value", pad_value),
+                         MakeArgument<bool>("replicate", replicate)};
 
     return SingleGradientDef(
         "DotProductWithPaddingGradient",
@@ -781,4 +775,4 @@ class GetDotProductWithPaddingGradient : public GradientMakerBase {
   }
 };
 REGISTER_GRADIENT(DotProductWithPadding, GetDotProductWithPaddingGradient);
-} // namespace caffe2
+}  // namespace caffe2
diff --git a/caffe2/operators/fc_inference.cc b/caffe2/operators/fc_inference.cc
index ba1b712..a44c230 100644
@@ -1,5 +1,4 @@
 #include "caffe2/operators/fc_inference.h"
-#include "caffe2/core/types.h"
 
 namespace caffe2 {
 std::vector<TensorShape> FCShapeInference(
@@ -52,12 +51,11 @@ OpSchema::Cost CostInferenceForFC(
       ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
       : size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
 
-  auto const& X_element_size_byte =
-      DataTypeToTypeMeta(in[0].data_type()).itemsize();
+  const auto& X = in[0];
   c.flops = M * N * (2 * K + 1);
-  c.bytes_read = (K * (M + N) + N) * X_element_size_byte;
-  c.bytes_written = M * N * X_element_size_byte;
-  c.params_bytes = (K * N + N) * X_element_size_byte;
+  c.bytes_read = (K * (M + N) + N) * sizeof(X.data_type());
+  c.bytes_written = M * N * sizeof(X.data_type());
+  c.params_bytes = (K * N + N) * sizeof(X.data_type());
   return c;
 }
 
@@ -96,11 +94,7 @@ OpSchema::Cost CostInferenceForFCGradient(
 
   CAFFE_ENFORCE_LT(0, out.size());
   const TensorShape dW = out[0];
-  auto const& dW_element_size_byte =
-      DataTypeToTypeMeta(dW.data_type()).itemsize();
   const TensorShape db = out[1];
-  auto const& db_element_size_byte =
-      DataTypeToTypeMeta(db.data_type()).itemsize();
 
   auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
   const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
@@ -117,17 +111,15 @@ OpSchema::Cost CostInferenceForFCGradient(
   uint64_t size_db = nElemFromDim(db);
 
   c.flops = M * N * (2 * K + 1);
-  c.bytes_written =
-      size_dW * dW_element_size_byte + size_db * db_element_size_byte;
+  c.bytes_written = (size_dW + size_db) * sizeof(float);
   c.params_bytes = (K * N + N) * sizeof(float);
 
   if (out.size() == 3) {
     const TensorShape dX = out[2];
     uint64_t size_dX = nElemFromDim(dX);
-    auto const& dX_element_size_byte =
-        DataTypeToTypeMeta(dX.data_type()).itemsize();
+
     c.flops += 2 * M * N * K;
-    c.bytes_written += size_dX * dX_element_size_byte;
+    c.bytes_written += size_dX * sizeof(float);
   }
   return c;
 }
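A hedged worked example of the FC cost above (shapes invented): X is MxK, W is NxK, b has N entries, with M=32, K=64, N=128 and 4-byte elements assumed throughout.

    #include <cstdio>

    int main() {
      unsigned long long M = 32, K = 64, N = 128, elem = 4;
      unsigned long long flops = M * N * (2 * K + 1);            // mul + add per k, plus bias
      unsigned long long bytes_read = (K * (M + N) + N) * elem;  // X, W, and b
      unsigned long long bytes_written = M * N * elem;           // Y
      unsigned long long params_bytes = (K * N + N) * elem;      // W and b
      std::printf("flops=%llu read=%llu written=%llu params=%llu\n",
                  flops, bytes_read, bytes_written, params_bytes);
    }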
diff --git a/caffe2/operators/one_hot_ops.cc b/caffe2/operators/one_hot_ops.cc
index 55c73a5..c3eaf05 100644
@@ -2,7 +2,6 @@
 
 #include "caffe2/core/operator.h"
 #include "caffe2/core/tensor.h"
-#include "caffe2/core/types.h"
 
 namespace caffe2 {
 
@@ -79,21 +78,12 @@ OpSchema::Cost CostInferenceForBatchOneHot(
   const auto& length = in[1];
   const auto& values = in[2];
 
-  auto const& data_element_size_byte =
-      DataTypeToTypeMeta(data.data_type()).itemsize();
-  auto const& length_element_size_byte =
-      DataTypeToTypeMeta(length.data_type()).itemsize();
-  auto const& values_element_size_byte =
-      DataTypeToTypeMeta(values.data_type()).itemsize();
-  auto const& output_element_size_byte =
-      DataTypeToTypeMeta(output.data_type()).itemsize();
-
-  uint64_t nBytesData = nElemFromDim(data) * data_element_size_byte;
-  uint64_t nBytesLength = nElemFromDim(length) * length_element_size_byte;
-  uint64_t nBytesValues = nElemFromDim(values) * values_element_size_byte;
+  uint64_t nBytesData = nElemFromDim(data) * sizeof(data.data_type());
+  uint64_t nBytesLength = nElemFromDim(length) * sizeof(length.data_type());
+  uint64_t nBytesValues = nElemFromDim(values) * sizeof(values.data_type());
   c.flops = 0;
   c.bytes_read = nBytesData + nBytesLength + nBytesValues;
-  c.bytes_written = nElemFromDim(output) * output_element_size_byte;
+  c.bytes_written = nElemFromDim(output) * sizeof(output.data_type());
   c.params_bytes = 0;
   return c;
 }
@@ -155,15 +145,15 @@ bool BatchBucketOneHotOp<CPUContext>::RunOnDevice() {
     for (int64_t j = 0; j < D; j++) {
       // here we assume the boundary values for each feature are sorted
       int64_t lower_bucket_idx = std::lower_bound(
-                                     boundaries_offset,
-                                     boundaries_offset + lens_data[j],
-                                     input_data[pos]) -
+                                    boundaries_offset,
+                                    boundaries_offset + lens_data[j],
+                                    input_data[pos]) -
           boundaries_offset;
 
       int64_t upper_bucket_idx = std::upper_bound(
-                                     boundaries_offset,
-                                     boundaries_offset + lens_data[j],
-                                     input_data[pos]) -
+                                    boundaries_offset,
+                                    boundaries_offset + lens_data[j],
+                                    input_data[pos]) -
           boundaries_offset;
 
       int64_t bucket_idx = (lower_bucket_idx + upper_bucket_idx) / 2;
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index 561da91..8b5e116 100644
@@ -1,7 +1,6 @@
 #include "caffe2/operators/utility_ops.h"
 #include <cmath>
 #include <iostream>
-#include "caffe2/core/types.h"
 #include "caffe2/utils/eigen_utils.h"
 
 namespace caffe2 {
@@ -35,11 +34,9 @@ OpSchema::Cost CostInferenceForWeightedSum(
   const auto& nElem = nElemFromDim(X0);
   const auto& nInputs = in.size();
   c.flops = (nInputs - 1) * nElem;
-  auto const& X0_element_size_byte =
-      DataTypeToTypeMeta(X0.data_type()).itemsize();
-  c.bytes_read = (nInputs / 2) * (nElem + 1) * X0_element_size_byte;
-  c.bytes_written = nElem * X0_element_size_byte;
-  c.params_bytes = (nInputs / 2) * X0_element_size_byte;
+  c.bytes_read = (nInputs / 2) * (nElem + 1) * sizeof(X0.data_type());
+  c.bytes_written = nElem * sizeof(X0.data_type());
+  c.params_bytes = (nInputs / 2) * sizeof(X0.data_type());
   return c;
 }
 
@@ -51,7 +48,9 @@ REGISTER_CPU_OPERATOR(ResizeLike, ResizeLikeOp<CPUContext>);
 REGISTER_CPU_OPERATOR(SumInt, SumOp<CPUContext>);
 REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp<CPUContext>);
 REGISTER_CPU_OPERATOR(WeightedSumGradient, WeightedSumGradientOp<CPUContext>);
-REGISTER_CPU_OPERATOR(ScatterWeightedSum, ScatterWeightedSumOp<CPUContext>);
+REGISTER_CPU_OPERATOR(
+    ScatterWeightedSum,
+    ScatterWeightedSumOp<CPUContext>);
 REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp<CPUContext>);
 REGISTER_CPU_OPERATOR(Scatter, ScatterOp<CPUContext>);
 
diff --git a/caffe2/python/operator_test/concat_op_cost_test.py b/caffe2/python/operator_test/concat_op_cost_test.py
index 7dab4d6..996b330 100644
@@ -7,39 +7,33 @@ from caffe2.python.test_util import TestCase
 
 class TestConcatOpCost(TestCase):
     def test_columnwise_concat(self):
-        def _test_columnwise_concat_for_type(dtype):
-            workspace.ResetWorkspace()
-            workspace.FeedBlob("input_1", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
-            workspace.FeedBlob("input_2", np.array([[7], [8]], dtype=dtype))
-            concat_op = core.CreateOperator(
-                "Concat",
-                ["input_1", "input_2"],
-                ["output", "split_info"],
-            )
-            workspace.RunOperatorOnce(concat_op)
-
-            output = workspace.FetchBlob("output")
-            self.assertTupleEqual(output.shape, (2, 4))
-            np.testing.assert_array_equal(output, [[1, 2, 3, 7], [4, 5, 6, 8]])
+        workspace.ResetWorkspace()
+        workspace.FeedBlob("input_1", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+        workspace.FeedBlob("input_2", np.array([[7], [8]], dtype=np.int32))
+        concat_op = core.CreateOperator(
+            "Concat",
+            ["input_1", "input_2"],
+            ["output", "split_info"],
+        )
+        workspace.RunOperatorOnce(concat_op)
 
-            flops, bytes_written, bytes_read = workspace.GetOperatorCost(
-                concat_op, concat_op.input
-            )
+        output = workspace.FetchBlob("output")
+        self.assertTupleEqual(output.shape, (2, 4))
+        np.testing.assert_array_equal(output, [[1, 2, 3, 7], [4, 5, 6, 8]])
 
-            self.assertEqual(flops, 0)
-            self.assertEqual(
-                bytes_read,
-                sum(workspace.FetchBlob(b).nbytes for b in concat_op.input),
-            )
-            self.assertEqual(
-                bytes_written,
-                sum(workspace.FetchBlob(b).nbytes for b in concat_op.output),
-            )
+        flops, bytes_written, bytes_read = workspace.GetOperatorCost(
+            concat_op, concat_op.input
+        )
 
-        [
-            _test_columnwise_concat_for_type(t)
-            for t in [np.int64, np.float, np.half, np.int8]
-        ]
+        self.assertEqual(flops, 0)
+        self.assertEqual(
+            bytes_read,
+            sum(workspace.FetchBlob(b).nbytes for b in concat_op.input),
+        )
+        self.assertEqual(
+            bytes_written,
+            sum(workspace.FetchBlob(b).nbytes for b in concat_op.output),
+        )
 
     def test_split_then_concat(self):
         workspace.ResetWorkspace()
diff --git a/caffe2/python/workspace_test.py b/caffe2/python/workspace_test.py
index 1bf7b60..afb2065 100644
@@ -60,7 +60,7 @@ class TestWorkspace(unittest.TestCase):
         self.assertTupleEqual(
             op_cost,
             namedtuple("Cost", ["flops", "bytes_written", "bytes_read"])(
-                1152, 256, 4168
+                1152, 256, 2084
             ),
         )
 
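Only bytes_read changes in this expectation, and the revert exactly halves it (4168 to 2084); flops and bytes_written are untouched. That factor of two is what one would expect if the blobs read by this test use 8-byte elements that the restored sizeof-based formula counts as sizeof(enum) = 4 bytes each, though the exact shapes are not visible in this hunk.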
diff --git a/caffe2/sgd/adagrad_op.cc b/caffe2/sgd/adagrad_op.cc
index 0b6f604..0de50f0 100644
@@ -1,5 +1,4 @@
 #include "adagrad_op.h"
-#include "caffe2/core/types.h"
 
 namespace caffe2 {
 
@@ -24,30 +23,22 @@ static OpSchema::Cost CostInferenceForAdagrad(
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
   c.flops = grad_size * 10;
 
-  auto const& moment_element_size_byte =
-      DataTypeToTypeMeta(moment.data_type()).itemsize();
-  auto const& param_element_size_byte =
-      DataTypeToTypeMeta(param.data_type()).itemsize();
-  auto const& grad_element_size_byte =
-      DataTypeToTypeMeta(grad.data_type()).itemsize();
-  auto const& lr_element_size_byte =
-      DataTypeToTypeMeta(lr.data_type()).itemsize();
   uint64_t bytes_written =
-      grad_size * param_element_size_byte + moment_element_size_byte;
+      grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type()));
 
   if (output_size == 3) {
     // also need to output effective learning rate in this case
     // assume it's the same data type as lr
-    bytes_written += grad_size * lr_element_size_byte;
+    bytes_written += grad_size * sizeof(lr.data_type());
   } else if (output_size == 4) {
     // also need to output effective learning rate and updates in this case
     // assume update is the same data type as param
     bytes_written +=
-        grad_size * (lr_element_size_byte + param_element_size_byte);
+        grad_size * (sizeof(lr.data_type()) + sizeof(param.data_type()));
   }
   c.bytes_written = bytes_written;
   c.bytes_read = c.bytes_written +
-      grad_size * (grad_element_size_byte + lr_element_size_byte);
+      grad_size * (sizeof(grad.data_type()) + sizeof(lr.data_type()));
 
   return c;
 }
@@ -111,18 +102,10 @@ static OpSchema::Cost CostInferenceForSparseAdagrad(
   // (optimistically count sqrt as one flop).
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
   c.flops = grad_size * 7;
-  auto const& param_element_size_byte =
-      DataTypeToTypeMeta(param.data_type()).itemsize();
-  auto const& moment_element_size_byte =
-      DataTypeToTypeMeta(moment.data_type()).itemsize();
   c.bytes_written =
-      grad_size * (param_element_size_byte + moment_element_size_byte);
-  auto const& grad_element_size_byte =
-      DataTypeToTypeMeta(grad.data_type()).itemsize();
-  auto const& indices_element_size_byte =
-      DataTypeToTypeMeta(indices.data_type()).itemsize();
-  c.bytes_read = c.bytes_written + grad_size * grad_element_size_byte +
-      n * indices_element_size_byte;
+      grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type()));
+  c.bytes_read = c.bytes_written + grad_size * sizeof(grad.data_type()) +
+      n * sizeof(indices.data_type());
 
   return c;
 }
@@ -170,16 +153,6 @@ static OpSchema::Cost CostInferenceForRowWiseSparseAdagrad(
   OpSchema::Cost c;
 
   if (n > 0) {
-    auto const& param_element_size_byte =
-        DataTypeToTypeMeta(param.data_type()).itemsize();
-    auto const& moment_element_size_byte =
-        DataTypeToTypeMeta(moment.data_type()).itemsize();
-    auto const& grad_element_size_byte =
-        DataTypeToTypeMeta(grad.data_type()).itemsize();
-    auto const& indices_element_size_byte =
-        DataTypeToTypeMeta(indices.data_type()).itemsize();
-    auto const& lr_element_size_byte =
-        DataTypeToTypeMeta(lr.data_type()).itemsize();
     auto block_size = grad_size / n;
     if (block_size == 1) {
       // +2: applying weight decay and add to grads
@@ -188,22 +161,22 @@ static OpSchema::Cost CostInferenceForRowWiseSparseAdagrad(
       // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
       c.flops = n * 9;
       c.bytes_written =
-          n * (param_element_size_byte + moment_element_size_byte);
+          n * (sizeof(param.data_type()) + sizeof(moment.data_type()));
       c.bytes_read = c.bytes_written +
           n *
-              (grad_element_size_byte + indices_element_size_byte +
-               lr_element_size_byte);
+              (sizeof(grad.data_type()) + sizeof(indices.data_type()) +
+               sizeof(lr.data_type()));
     } else {
       // 5 per block (not counting index transforms)
       // 8 for each value of a block
       // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
       c.flops = n * (5 + (block_size * 8));
-      c.bytes_written = n * moment_element_size_byte +
-          n * block_size * param_element_size_byte;
+      c.bytes_written =
+          n * sizeof(moment.data_type()) + n * block_size * (param.data_type());
 
-      c.bytes_read = c.bytes_written + n * lr_element_size_byte +
+      c.bytes_read = c.bytes_written + n * (sizeof(lr.data_type())) +
           2 * n * block_size *
-              (grad_element_size_byte + param_element_size_byte);
+              (sizeof(grad.data_type()) + sizeof(param.data_type()));
     }
   }
   return c;
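One hedged observation on the restored RowWiseSparseAdagrad branch above: the term n * block_size * (param.data_type()) multiplies by the enum value itself (TensorProto::FLOAT is 1 in caffe2.proto), with no sizeof at all, so even under the pre-fix sizeof convention it looks like a missing sizeof. A stand-in sketch of the discrepancy:

    #include <cstdio>

    enum DataType { FLOAT = 1 };  // stand-in for TensorProto::DataType

    int main() {
      unsigned long long n = 100, block_size = 64;
      DataType t = FLOAT;
      std::printf("%llu vs %llu\n",
                  n * block_size * t,           // 6400: what the restored line computes
                  n * block_size * sizeof(t));  // 25600: with sizeof, on a typical 4-byte-enum ABI
    }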