From e98173ff3423247c597e21c923c8f47470ef07ab Mon Sep 17 00:00:00 2001 From: Tanvir Zaman Date: Mon, 30 Aug 2021 12:56:15 -0700 Subject: [PATCH] Fix bytes_written and bytes_read (#64040) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64040 In operator cost inference functions, in many places we are using sizeof(x.data_type()). Since data_type() returns a 32 bit integer from [this enum](https://www.internalfb.com/code/fbsource/[15e7ffe4073cf08c61077c7c24a4839504b964a2]/fbcode/caffe2/caffe2/proto/caffe2.proto?lines=20), we are basically always getting 4 for sizeof(x.data_type()) no matter what actual data type x has. Big thanks to Jack Langman for specifically pointing to this bug. We would instead use the size in bytes based on actual data type. Test Plan: Added unit tests BatchMatMulMemCostTest: buck test //caffe2/caffe2/fb/fbgemm:batch_matmul_op_test -- BatchMatMulMemCostTest Extended existing unit test test_columnwise_concat for different data types: buck test //caffe2/caffe2/python/operator_test:concat_op_cost_test -- test_columnwise_concat Differential Revision: D30561459 fbshipit-source-id: 976fa5167097a35af548498480001aafd7851d93 --- caffe2/core/operator_schema.h | 17 ++-- caffe2/operators/batch_matmul_op.cc | 113 +++++++++++---------- caffe2/operators/concat_split_op.cc | 15 ++- caffe2/operators/conv_pool_op_base.h | 15 ++- caffe2/operators/distance_op.cc | 28 +++-- caffe2/operators/fc_inference.cc | 22 ++-- caffe2/operators/one_hot_ops.cc | 30 ++++-- caffe2/operators/utility_ops.cc | 13 +-- caffe2/python/operator_test/concat_op_cost_test.py | 54 +++++----- caffe2/python/workspace_test.py | 2 +- caffe2/sgd/adagrad_op.cc | 55 +++++++--- 11 files changed, 224 insertions(+), 140 deletions(-) diff --git a/caffe2/core/operator_schema.h b/caffe2/core/operator_schema.h index 64f5ef3..0d048eb 100644 --- a/caffe2/core/operator_schema.h +++ b/caffe2/core/operator_schema.h @@ -6,12 +6,13 @@ #include #include #include -#include #include +#include #include "c10/util/Registry.h" #include "caffe2/core/common.h" #include "caffe2/core/logging.h" +#include "caffe2/core/types.h" #include "caffe2/proto/caffe2_pb.h" #include "caffe2/utils/filler.h" #include "caffe2/utils/proto_utils.h" @@ -273,8 +274,8 @@ class TORCH_API OpSchema { OpSchema& Arg(const char* name, const char* description, bool required = false); -#define DECLARE_STANDARD_ARG(name, str) \ - static const char* Arg_##name; \ +#define DECLARE_STANDARD_ARG(name, str) \ + static const char* Arg_##name; \ OpSchema& Arg##name(const char* description); DECLARE_STANDARD_ARG(IsTest, is_test) @@ -339,7 +340,9 @@ class TORCH_API OpSchema { return inplace_enforced_(x, y); } - TORCH_API friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema); + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const OpSchema& schema); const std::vector& args() const { return args_; @@ -562,8 +565,10 @@ OpSchema::Cost PointwiseCostInference( } c.flops = nElemX * OpsPerPoint; - c.bytes_read = nElemRead * sizeof(X.data_type()); - c.bytes_written = nElemX * sizeof(X.data_type()); + auto const& X_element_size_byte = + DataTypeToTypeMeta(X.data_type()).itemsize(); + c.bytes_read = nElemRead * X_element_size_byte; + c.bytes_written = nElemX * X_element_size_byte; return c; } diff --git a/caffe2/operators/batch_matmul_op.cc b/caffe2/operators/batch_matmul_op.cc index 32799ce..205acf7 100644 --- a/caffe2/operators/batch_matmul_op.cc +++ b/caffe2/operators/batch_matmul_op.cc @@ -1,6 +1,7 @@ #include 
"caffe2/operators/batch_matmul_op.h" #include "caffe2/core/operator_schema.h" +#include "caffe2/core/types.h" namespace caffe2 { @@ -116,9 +117,13 @@ OpSchema::Cost CostInferenceForBatchMatMul( K = in[0].dims(ndims_A - 1); } + auto const& A_element_size_byte = + DataTypeToTypeMeta(A.data_type()).itemsize(); + auto const& Y_element_size_byte = + DataTypeToTypeMeta(Y.data_type()).itemsize(); c.flops = 2 * nElemY * K; - c.bytes_read = (nElemA + nElemB) * sizeof(A.data_type()); - c.bytes_written = nElemY * sizeof(Y.data_type()); + c.bytes_read = (nElemA + nElemB) * A_element_size_byte; + c.bytes_written = nElemY * Y_element_size_byte; c.params_bytes = 0; return c; } @@ -180,72 +185,76 @@ class GetBatchMatMulGradient : public GradientMakerBase { auto no_trans_arg = vector(); auto trans_a_arg = vector{MakeArgument("trans_a", 1)}; auto trans_b_arg = vector{MakeArgument("trans_b", 1)}; - auto trans_both_arg = vector{MakeArgument("trans_a", 1), - MakeArgument("trans_b", 1)}; + auto trans_both_arg = vector{ + MakeArgument("trans_a", 1), MakeArgument("trans_b", 1)}; if (trans_a) { if (trans_b) { // A'B': // dA = B'G', dB = G'A' - return vector{CreateOperatorDef( - "BatchMatMul", - "", - vector{I(1), GO(0)}, - vector{GI(0)}, - trans_both_arg), - CreateOperatorDef( - "BatchMatMul", - "", - vector{GO(0), I(0)}, - vector{GI(1)}, - trans_both_arg)}; + return vector{ + CreateOperatorDef( + "BatchMatMul", + "", + vector{I(1), GO(0)}, + vector{GI(0)}, + trans_both_arg), + CreateOperatorDef( + "BatchMatMul", + "", + vector{GO(0), I(0)}, + vector{GI(1)}, + trans_both_arg)}; } else { // A'B: // dA = BG', dB = AG - return vector{CreateOperatorDef( - "BatchMatMul", - "", - vector{I(1), GO(0)}, - vector{GI(0)}, - trans_b_arg), - CreateOperatorDef( - "BatchMatMul", - "", - vector{I(0), GO(0)}, - vector{GI(1)}, - no_trans_arg)}; + return vector{ + CreateOperatorDef( + "BatchMatMul", + "", + vector{I(1), GO(0)}, + vector{GI(0)}, + trans_b_arg), + CreateOperatorDef( + "BatchMatMul", + "", + vector{I(0), GO(0)}, + vector{GI(1)}, + no_trans_arg)}; } } else { if (trans_b) { // AB': // dA = GB, dB = G'A - return vector{CreateOperatorDef( - "BatchMatMul", - "", - vector{GO(0), I(1)}, - vector{GI(0)}, - no_trans_arg), - CreateOperatorDef( - "BatchMatMul", - "", - vector{GO(0), I(0)}, - vector{GI(1)}, - trans_a_arg)}; + return vector{ + CreateOperatorDef( + "BatchMatMul", + "", + vector{GO(0), I(1)}, + vector{GI(0)}, + no_trans_arg), + CreateOperatorDef( + "BatchMatMul", + "", + vector{GO(0), I(0)}, + vector{GI(1)}, + trans_a_arg)}; } else { // AB: // dA = GB', dB = A'G - return vector{CreateOperatorDef( - "BatchMatMul", - "", - vector{GO(0), I(1)}, - vector{GI(0)}, - trans_b_arg), - CreateOperatorDef( - "BatchMatMul", - "", - vector{I(0), GO(0)}, - vector{GI(1)}, - trans_a_arg)}; + return vector{ + CreateOperatorDef( + "BatchMatMul", + "", + vector{GO(0), I(1)}, + vector{GI(0)}, + trans_b_arg), + CreateOperatorDef( + "BatchMatMul", + "", + vector{I(0), GO(0)}, + vector{GI(1)}, + trans_a_arg)}; } } } diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc index 8eceb5a..8aa9e28 100644 --- a/caffe2/operators/concat_split_op.cc +++ b/caffe2/operators/concat_split_op.cc @@ -101,9 +101,12 @@ OpSchema::Cost CostInferenceForSplit( CAFFE_ENFORCE_GT(in.size(), 0); struct OpSchema::Cost cost; cost.flops = 0; - auto input_bytes_count = nElemFromDim(in[0]) * sizeof(in[0].data_type()); - auto split_bytes_count = - (in.size() == 1) ? 
0 : nElemFromDim(in[1]) * sizeof(in[1].data_type()); + auto const& input_0_element_size_byte = + DataTypeToTypeMeta(in[0].data_type()).itemsize(); + auto const& input_1_element_size_byte = + (in.size() > 1) ? DataTypeToTypeMeta(in[1].data_type()).itemsize() : 0; + auto input_bytes_count = nElemFromDim(in[0]) * input_0_element_size_byte; + auto split_bytes_count = nElemFromDim(in[1]) * input_1_element_size_byte; // There can be two input blobs: // (1) actual tensor to be split // (2) lengths of outputs along split axis @@ -329,11 +332,13 @@ OpSchema::Cost CostInferenceForConcat( } auto split_info_bytes_count = in.size() * sizeof(int); + auto const& input_0_element_size_byte = + DataTypeToTypeMeta(in[0].data_type()).itemsize(); struct OpSchema::Cost cost; cost.flops = 0; - cost.bytes_read = nElemRead * sizeof(in[0].data_type()); + cost.bytes_read = nElemRead * input_0_element_size_byte; cost.bytes_written = - size * sizeof(in[0].data_type()) + split_info_bytes_count; + size * input_0_element_size_byte + split_info_bytes_count; cost.params_bytes = 0; return cost; } diff --git a/caffe2/operators/conv_pool_op_base.h b/caffe2/operators/conv_pool_op_base.h index 25bd99a..b356ef9 100644 --- a/caffe2/operators/conv_pool_op_base.h +++ b/caffe2/operators/conv_pool_op_base.h @@ -7,6 +7,7 @@ #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" +#include "caffe2/core/types.h" #include "caffe2/proto/caffe2_legacy.pb.h" #include "caffe2/utils/math.h" @@ -519,14 +520,20 @@ class ConvPoolOpBase : public Operator { uint64_t nElemW = nElemFromDim(W); uint64_t nElemBias = inputs.size() > 2 ? nElemFromDim(inputs[2]) : 0; + auto const& X_elemenet_size_byte = + DataTypeToTypeMeta(X.data_type()).itemsize(); + auto const& Y_element_size_byte = + DataTypeToTypeMeta(Y.data_type()).itemsize(); + auto const& W_element_size_byte = + DataTypeToTypeMeta(W.data_type()).itemsize(); + // grouping is NOT properly handled yet c.flops = N * Y_t * Y_h * Y_w * kernel_t * kernel_w * kernel_h * in_channels * out_channels * 2; - c.bytes_read = (nElemX + nElemW + nElemBias) * sizeof(X.data_type()); - c.bytes_written = - N * out_channels * Y_t * Y_h * Y_w * sizeof(Y.data_type()); + c.bytes_read = (nElemX + nElemW + nElemBias) * X_elemenet_size_byte; + c.bytes_written = N * out_channels * Y_t * Y_h * Y_w * Y_element_size_byte; c.params_bytes = out_channels * in_channels * kernel_t * kernel_h * - kernel_w * sizeof(W.data_type()); + kernel_w * W_element_size_byte; return c; } diff --git a/caffe2/operators/distance_op.cc b/caffe2/operators/distance_op.cc index 1529534..9ea8eea 100644 --- a/caffe2/operators/distance_op.cc +++ b/caffe2/operators/distance_op.cc @@ -1,4 +1,5 @@ #include "caffe2/operators/distance_op.h" +#include "caffe2/core/types.h" #include "caffe2/utils/eigen_utils.h" #ifdef CAFFE2_USE_MKLDNN #include @@ -7,7 +8,7 @@ namespace caffe2 { -template<> +template <> bool SquaredL2DistanceOp::RunOnDevice() { auto& X = Input(0); auto& Y = Input(1); @@ -257,7 +258,9 @@ OpSchema::Cost CostInferenceForDotProduct( CAFFE_ENFORCE_EQ(out[0].dims().size(), 1); struct OpSchema::Cost c = PointwiseCostInference<2>(def, in); - c.bytes_written = out[0].dims(0) * sizeof(out[0].data_type()); + auto const& out_0_element_size_byte = + DataTypeToTypeMeta(out[0].data_type()).itemsize(); + c.bytes_written = out[0].dims(0) * out_0_element_size_byte; c.params_bytes = 0; return c; } @@ -379,10 +382,12 @@ bool DotProductWithPaddingOp::RunOnDevice() { } // L2 -REGISTER_CPU_OPERATOR(SquaredL2Distance, - 
SquaredL2DistanceOp); -REGISTER_CPU_OPERATOR(SquaredL2DistanceGradient, - SquaredL2DistanceGradientOp); +REGISTER_CPU_OPERATOR( + SquaredL2Distance, + SquaredL2DistanceOp); +REGISTER_CPU_OPERATOR( + SquaredL2DistanceGradient, + SquaredL2DistanceGradientOp); OPERATOR_SCHEMA(SquaredL2Distance) .NumInputs(2) @@ -402,7 +407,8 @@ class GetSquaredL2DistanceGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; vector GetGradientDefs() override { return SingleGradientDef( - "SquaredL2DistanceGradient", "", + "SquaredL2DistanceGradient", + "", vector{I(0), I(1), GO(0)}, vector{GI(0), GI(1)}); } @@ -762,9 +768,9 @@ class GetDotProductWithPaddingGradient : public GradientMakerBase { replicate = GetArgument(Def(), "replicate").i(); } - const auto dot_arg = - vector{MakeArgument("pad_value", pad_value), - MakeArgument("replicate", replicate)}; + const auto dot_arg = vector{ + MakeArgument("pad_value", pad_value), + MakeArgument("replicate", replicate)}; return SingleGradientDef( "DotProductWithPaddingGradient", @@ -775,4 +781,4 @@ class GetDotProductWithPaddingGradient : public GradientMakerBase { } }; REGISTER_GRADIENT(DotProductWithPadding, GetDotProductWithPaddingGradient); -} // namespace caffe2 +} // namespace caffe2 diff --git a/caffe2/operators/fc_inference.cc b/caffe2/operators/fc_inference.cc index a44c230..ba1b712 100644 --- a/caffe2/operators/fc_inference.cc +++ b/caffe2/operators/fc_inference.cc @@ -1,4 +1,5 @@ #include "caffe2/operators/fc_inference.h" +#include "caffe2/core/types.h" namespace caffe2 { std::vector FCShapeInference( @@ -51,11 +52,12 @@ OpSchema::Cost CostInferenceForFC( ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1])) : size_to_dim_(canonical_axis_w, GetDimsVector(in[1])); - const auto& X = in[0]; + auto const& X_element_size_byte = + DataTypeToTypeMeta(in[0].data_type()).itemsize(); c.flops = M * N * (2 * K + 1); - c.bytes_read = (K * (M + N) + N) * sizeof(X.data_type()); - c.bytes_written = M * N * sizeof(X.data_type()); - c.params_bytes = (K * N + N) * sizeof(X.data_type()); + c.bytes_read = (K * (M + N) + N) * X_element_size_byte; + c.bytes_written = M * N * X_element_size_byte; + c.params_bytes = (K * N + N) * X_element_size_byte; return c; } @@ -94,7 +96,11 @@ OpSchema::Cost CostInferenceForFCGradient( CAFFE_ENFORCE_LT(0, out.size()); const TensorShape dW = out[0]; + auto const& dW_element_size_byte = + DataTypeToTypeMeta(dW.data_type()).itemsize(); const TensorShape db = out[1]; + auto const& db_element_size_byte = + DataTypeToTypeMeta(db.data_type()).itemsize(); auto axis = helper.GetSingleArgument("axis", 1); const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size()); @@ -111,15 +117,17 @@ OpSchema::Cost CostInferenceForFCGradient( uint64_t size_db = nElemFromDim(db); c.flops = M * N * (2 * K + 1); - c.bytes_written = (size_dW + size_db) * sizeof(float); + c.bytes_written = + size_dW * dW_element_size_byte + size_db * db_element_size_byte; c.params_bytes = (K * N + N) * sizeof(float); if (out.size() == 3) { const TensorShape dX = out[2]; uint64_t size_dX = nElemFromDim(dX); - + auto const& dX_element_size_byte = + DataTypeToTypeMeta(dX.data_type()).itemsize(); c.flops += 2 * M * N * K; - c.bytes_written += size_dX * sizeof(float); + c.bytes_written += size_dX * dX_element_size_byte; } return c; } diff --git a/caffe2/operators/one_hot_ops.cc b/caffe2/operators/one_hot_ops.cc index c3eaf05..55c73a5 100644 --- a/caffe2/operators/one_hot_ops.cc +++ b/caffe2/operators/one_hot_ops.cc @@ -2,6 +2,7 @@ #include 
"caffe2/core/operator.h" #include "caffe2/core/tensor.h" +#include "caffe2/core/types.h" namespace caffe2 { @@ -78,12 +79,21 @@ OpSchema::Cost CostInferenceForBatchOneHot( const auto& length = in[1]; const auto& values = in[2]; - uint64_t nBytesData = nElemFromDim(data) * sizeof(data.data_type()); - uint64_t nBytesLength = nElemFromDim(length) * sizeof(length.data_type()); - uint64_t nBytesValues = nElemFromDim(values) * sizeof(values.data_type()); + auto const& data_element_size_byte = + DataTypeToTypeMeta(data.data_type()).itemsize(); + auto const& length_element_size_byte = + DataTypeToTypeMeta(length.data_type()).itemsize(); + auto const& values_element_size_byte = + DataTypeToTypeMeta(values.data_type()).itemsize(); + auto const& output_element_size_byte = + DataTypeToTypeMeta(output.data_type()).itemsize(); + + uint64_t nBytesData = nElemFromDim(data) * data_element_size_byte; + uint64_t nBytesLength = nElemFromDim(length) * length_element_size_byte; + uint64_t nBytesValues = nElemFromDim(values) * values_element_size_byte; c.flops = 0; c.bytes_read = nBytesData + nBytesLength + nBytesValues; - c.bytes_written = nElemFromDim(output) * sizeof(output.data_type()); + c.bytes_written = nElemFromDim(output) * output_element_size_byte; c.params_bytes = 0; return c; } @@ -145,15 +155,15 @@ bool BatchBucketOneHotOp::RunOnDevice() { for (int64_t j = 0; j < D; j++) { // here we assume the boundary values for each feature are sorted int64_t lower_bucket_idx = std::lower_bound( - boundaries_offset, - boundaries_offset + lens_data[j], - input_data[pos]) - + boundaries_offset, + boundaries_offset + lens_data[j], + input_data[pos]) - boundaries_offset; int64_t upper_bucket_idx = std::upper_bound( - boundaries_offset, - boundaries_offset + lens_data[j], - input_data[pos]) - + boundaries_offset, + boundaries_offset + lens_data[j], + input_data[pos]) - boundaries_offset; int64_t bucket_idx = (lower_bucket_idx + upper_bucket_idx) / 2; diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc index 8b5e116..561da91 100644 --- a/caffe2/operators/utility_ops.cc +++ b/caffe2/operators/utility_ops.cc @@ -1,6 +1,7 @@ #include "caffe2/operators/utility_ops.h" #include #include +#include "caffe2/core/types.h" #include "caffe2/utils/eigen_utils.h" namespace caffe2 { @@ -34,9 +35,11 @@ OpSchema::Cost CostInferenceForWeightedSum( const auto& nElem = nElemFromDim(X0); const auto& nInputs = in.size(); c.flops = (nInputs - 1) * nElem; - c.bytes_read = (nInputs / 2) * (nElem + 1) * sizeof(X0.data_type()); - c.bytes_written = nElem * sizeof(X0.data_type()); - c.params_bytes = (nInputs / 2) * sizeof(X0.data_type()); + auto const& X0_element_size_byte = + DataTypeToTypeMeta(X0.data_type()).itemsize(); + c.bytes_read = (nInputs / 2) * (nElem + 1) * X0_element_size_byte; + c.bytes_written = nElem * X0_element_size_byte; + c.params_bytes = (nInputs / 2) * X0_element_size_byte; return c; } @@ -48,9 +51,7 @@ REGISTER_CPU_OPERATOR(ResizeLike, ResizeLikeOp); REGISTER_CPU_OPERATOR(SumInt, SumOp); REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp); REGISTER_CPU_OPERATOR(WeightedSumGradient, WeightedSumGradientOp); -REGISTER_CPU_OPERATOR( - ScatterWeightedSum, - ScatterWeightedSumOp); +REGISTER_CPU_OPERATOR(ScatterWeightedSum, ScatterWeightedSumOp); REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp); REGISTER_CPU_OPERATOR(Scatter, ScatterOp); diff --git a/caffe2/python/operator_test/concat_op_cost_test.py b/caffe2/python/operator_test/concat_op_cost_test.py index 996b330..7dab4d6 100644 --- 
a/caffe2/python/operator_test/concat_op_cost_test.py +++ b/caffe2/python/operator_test/concat_op_cost_test.py @@ -7,33 +7,39 @@ from caffe2.python.test_util import TestCase class TestConcatOpCost(TestCase): def test_columnwise_concat(self): - workspace.ResetWorkspace() - workspace.FeedBlob("input_1", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) - workspace.FeedBlob("input_2", np.array([[7], [8]], dtype=np.int32)) - concat_op = core.CreateOperator( - "Concat", - ["input_1", "input_2"], - ["output", "split_info"], - ) - workspace.RunOperatorOnce(concat_op) + def _test_columnwise_concat_for_type(dtype): + workspace.ResetWorkspace() + workspace.FeedBlob("input_1", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) + workspace.FeedBlob("input_2", np.array([[7], [8]], dtype=dtype)) + concat_op = core.CreateOperator( + "Concat", + ["input_1", "input_2"], + ["output", "split_info"], + ) + workspace.RunOperatorOnce(concat_op) - output = workspace.FetchBlob("output") - self.assertTupleEqual(output.shape, (2, 4)) - np.testing.assert_array_equal(output, [[1, 2, 3, 7], [4, 5, 6, 8]]) + output = workspace.FetchBlob("output") + self.assertTupleEqual(output.shape, (2, 4)) + np.testing.assert_array_equal(output, [[1, 2, 3, 7], [4, 5, 6, 8]]) - flops, bytes_written, bytes_read = workspace.GetOperatorCost( - concat_op, concat_op.input - ) + flops, bytes_written, bytes_read = workspace.GetOperatorCost( + concat_op, concat_op.input + ) - self.assertEqual(flops, 0) - self.assertEqual( - bytes_read, - sum(workspace.FetchBlob(b).nbytes for b in concat_op.input), - ) - self.assertEqual( - bytes_written, - sum(workspace.FetchBlob(b).nbytes for b in concat_op.output), - ) + self.assertEqual(flops, 0) + self.assertEqual( + bytes_read, + sum(workspace.FetchBlob(b).nbytes for b in concat_op.input), + ) + self.assertEqual( + bytes_written, + sum(workspace.FetchBlob(b).nbytes for b in concat_op.output), + ) + + [ + _test_columnwise_concat_for_type(t) + for t in [np.int64, np.float, np.half, np.int8] + ] def test_split_then_concat(self): workspace.ResetWorkspace() diff --git a/caffe2/python/workspace_test.py b/caffe2/python/workspace_test.py index afb2065..1bf7b60 100644 --- a/caffe2/python/workspace_test.py +++ b/caffe2/python/workspace_test.py @@ -60,7 +60,7 @@ class TestWorkspace(unittest.TestCase): self.assertTupleEqual( op_cost, namedtuple("Cost", ["flops", "bytes_written", "bytes_read"])( - 1152, 256, 2084 + 1152, 256, 4168 ), ) diff --git a/caffe2/sgd/adagrad_op.cc b/caffe2/sgd/adagrad_op.cc index 0de50f0..0b6f604 100644 --- a/caffe2/sgd/adagrad_op.cc +++ b/caffe2/sgd/adagrad_op.cc @@ -1,4 +1,5 @@ #include "adagrad_op.h" +#include "caffe2/core/types.h" namespace caffe2 { @@ -23,22 +24,30 @@ static OpSchema::Cost CostInferenceForAdagrad( // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) c.flops = grad_size * 10; + auto const& moment_element_size_byte = + DataTypeToTypeMeta(moment.data_type()).itemsize(); + auto const& param_element_size_byte = + DataTypeToTypeMeta(param.data_type()).itemsize(); + auto const& grad_element_size_byte = + DataTypeToTypeMeta(grad.data_type()).itemsize(); + auto const& lr_element_size_byte = + DataTypeToTypeMeta(lr.data_type()).itemsize(); uint64_t bytes_written = - grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type())); + grad_size * param_element_size_byte + moment_element_size_byte; if (output_size == 3) { // also need to output effective learning rate in this case // assume it's the same data type as lr - bytes_written += grad_size * sizeof(lr.data_type()); + 
bytes_written += grad_size * lr_element_size_byte; } else if (output_size == 4) { // also need to output effective learning rate and updates in this case // assume update is the same data type as param bytes_written += - grad_size * (sizeof(lr.data_type()) + sizeof(param.data_type())); + grad_size * (lr_element_size_byte + param_element_size_byte); } c.bytes_written = bytes_written; c.bytes_read = c.bytes_written + - grad_size * (sizeof(grad.data_type()) + sizeof(lr.data_type())); + grad_size * (grad_element_size_byte + lr_element_size_byte); return c; } @@ -102,10 +111,18 @@ static OpSchema::Cost CostInferenceForSparseAdagrad( // (optimistically count sqrt as one flop). // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) c.flops = grad_size * 7; + auto const& param_element_size_byte = + DataTypeToTypeMeta(param.data_type()).itemsize(); + auto const& moment_element_size_byte = + DataTypeToTypeMeta(moment.data_type()).itemsize(); c.bytes_written = - grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type())); - c.bytes_read = c.bytes_written + grad_size * sizeof(grad.data_type()) + - n * sizeof(indices.data_type()); + grad_size * (param_element_size_byte + moment_element_size_byte); + auto const& grad_element_size_byte = + DataTypeToTypeMeta(grad.data_type()).itemsize(); + auto const& indices_element_size_byte = + DataTypeToTypeMeta(indices.data_type()).itemsize(); + c.bytes_read = c.bytes_written + grad_size * grad_element_size_byte + + n * indices_element_size_byte; return c; } @@ -153,6 +170,16 @@ static OpSchema::Cost CostInferenceForRowWiseSparseAdagrad( OpSchema::Cost c; if (n > 0) { + auto const& param_element_size_byte = + DataTypeToTypeMeta(param.data_type()).itemsize(); + auto const& moment_element_size_byte = + DataTypeToTypeMeta(moment.data_type()).itemsize(); + auto const& grad_element_size_byte = + DataTypeToTypeMeta(grad.data_type()).itemsize(); + auto const& indices_element_size_byte = + DataTypeToTypeMeta(indices.data_type()).itemsize(); + auto const& lr_element_size_byte = + DataTypeToTypeMeta(lr.data_type()).itemsize(); auto block_size = grad_size / n; if (block_size == 1) { // +2: applying weight decay and add to grads @@ -161,22 +188,22 @@ static OpSchema::Cost CostInferenceForRowWiseSparseAdagrad( // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) c.flops = n * 9; c.bytes_written = - n * (sizeof(param.data_type()) + sizeof(moment.data_type())); + n * (param_element_size_byte + moment_element_size_byte); c.bytes_read = c.bytes_written + n * - (sizeof(grad.data_type()) + sizeof(indices.data_type()) + - sizeof(lr.data_type())); + (grad_element_size_byte + indices_element_size_byte + + lr_element_size_byte); } else { // 5 per block (not counting index transforms) // 8 for each value of a block // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) c.flops = n * (5 + (block_size * 8)); - c.bytes_written = - n * sizeof(moment.data_type()) + n * block_size * (param.data_type()); + c.bytes_written = n * moment_element_size_byte + + n * block_size * param_element_size_byte; - c.bytes_read = c.bytes_written + n * (sizeof(lr.data_type())) + + c.bytes_read = c.bytes_written + n * lr_element_size_byte + 2 * n * block_size * - (sizeof(grad.data_type()) + sizeof(param.data_type())); + (grad_element_size_byte + param_element_size_byte); } } return c; -- 2.7.4
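
Editor's note (after the patch trailer, ignored by `git am`): the core pitfall this patch fixes is that `sizeof(x.data_type())` measures the data-type *enum value* (a 32-bit integer, hence always 4), not the element type the enum names; the fix maps the enum to a `TypeMeta` and asks for `itemsize()`. Below is a minimal standalone C++ sketch of that pitfall — plain C++, not caffe2 code; `DataType` and `ItemSize` are invented stand-ins for caffe2's enum and `DataTypeToTypeMeta(...).itemsize()`.

```cpp
// Standalone illustration of the bug pattern: sizeof() of a data-type enum
// value yields the size of the enum's underlying integer (4 bytes here),
// regardless of which element type the enum names.
#include <cstddef>
#include <cstdint>
#include <cstdio>

enum class DataType : int32_t { FLOAT = 1, DOUBLE = 2, INT8 = 3, FLOAT16 = 4 };

// Hypothetical stand-in for DataTypeToTypeMeta(dt).itemsize().
size_t ItemSize(DataType dt) {
  switch (dt) {
    case DataType::FLOAT:   return sizeof(float);   // 4
    case DataType::DOUBLE:  return sizeof(double);  // 8
    case DataType::INT8:    return sizeof(int8_t);  // 1
    case DataType::FLOAT16: return 2;
  }
  return 0;
}

int main() {
  DataType dt = DataType::DOUBLE;
  // Buggy pattern: always prints 4, no matter what dt is.
  std::printf("sizeof(enum value): %zu\n", sizeof(dt));
  // Fixed pattern: prints the real per-element size (8 for DOUBLE).
  std::printf("item size:          %zu\n", ItemSize(dt));
  return 0;
}
```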
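
To see how the element size feeds the cost numbers, here is a hedged sketch of the FC cost arithmetic touched by this patch, written as a free-standing function so the effect of the true item size can be checked by hand; `Cost` and `FcCost` are illustrative names, not the caffe2 API.

```cpp
// Sketch of CostInferenceForFC's byte accounting with an explicit element
// size: X is M x K, W is K x N, b has N elements, Y is M x N.
#include <cstdint>
#include <cstdio>

struct Cost {
  uint64_t flops = 0;
  uint64_t bytes_read = 0;
  uint64_t bytes_written = 0;
  uint64_t params_bytes = 0;
};

Cost FcCost(uint64_t M, uint64_t K, uint64_t N, uint64_t elem_size_bytes) {
  Cost c;
  c.flops = M * N * (2 * K + 1);                       // multiply-adds plus bias add
  c.bytes_read = (K * (M + N) + N) * elem_size_bytes;  // read X, W, and b
  c.bytes_written = M * N * elem_size_bytes;           // write Y
  c.params_bytes = (K * N + N) * elem_size_bytes;      // W and b
  return c;
}

int main() {
  // With float16 tensors the old sizeof(data_type()) code still charged
  // 4 bytes per element; using the real item size (2) halves every count.
  Cost old_cost = FcCost(8, 64, 16, 4);
  Cost new_cost = FcCost(8, 64, 16, 2);
  std::printf("bytes_read old=%llu new=%llu\n",
              (unsigned long long)old_cost.bytes_read,
              (unsigned long long)new_cost.bytes_read);
  return 0;
}
```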