add tensor and cost inference functions (#17684)
authorJongsoo Park <jongsoo@fb.com>
Thu, 7 Mar 2019 07:26:27 +0000 (23:26 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 7 Mar 2019 07:34:02 +0000 (23:34 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17684

Adding tensor and cost inference functions to more int8 operators.

Reviewed By: yinghai

Differential Revision: D14174746

fbshipit-source-id: dfad975fa75899565c8fb61f1b7747a9206ebd22

caffe2/operators/concat_split_op.cc
caffe2/operators/concat_split_op.h
caffe2/operators/elementwise_sum_op.cc
caffe2/operators/flatten_op.cc
caffe2/operators/flatten_op.h
caffe2/operators/quantized/int8_add_op.cc
caffe2/operators/quantized/int8_concat_op.cc
caffe2/operators/quantized/int8_fc_op.cc
caffe2/operators/quantized/int8_flatten_op.cc
caffe2/operators/quantized/int8_given_tensor_fill_op.cc
caffe2/operators/utility_ops.h

index 57ef45c..ff66578 100644 (file)
@@ -106,7 +106,6 @@ Split a tensor into a list of tensors, given a lengths input, along the specifie
 The `input` will be split into `K` parts. Each part of length
 `sum(lengths[i*k:i*k+k))`)DOC");
 
-namespace {
 OpSchema::Cost CostInferenceForConcat(
     const OperatorDef& def,
     const vector<TensorShape>& in) {
@@ -143,6 +142,7 @@ OpSchema::Cost CostInferenceForConcat(
   return cost;
 }
 
+namespace {
 std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>
 concatOpDevInfer(const OperatorDef& def) {
   auto op_device =
@@ -157,6 +157,80 @@ concatOpDevInfer(const OperatorDef& def) {
 }
 } // namespace
 
+vector<TensorShape> TensorInferenceForConcat(
+    const OperatorDef& def,
+    const vector<TensorShape>& in) {
+  ArgumentHelper helper(def);
+  const int axis = helper.HasArgument("axis")
+      ? helper.GetSingleArgument<int>("axis", -1)
+      : GetDimFromOrderString(
+            helper.GetSingleArgument<string>("order", "NCHW"));
+  bool add_axis = helper.GetSingleArgument<int>("add_axis", 0) != 0;
+  int adj_size = in[0].dims_size() + (add_axis ? 1 : 0);
+  const int canonical_axis = canonical_axis_index_(axis, adj_size);
+  CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
+  CAFFE_ENFORCE_GT(in.size(), 0);
+  vector<int> split_shape(1, in.size());
+  vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());
+  if (add_axis) {
+    for (int i = 1; i < in.size(); ++i) {
+      CAFFE_ENFORCE_EQ(
+          in[0].dims().size(),
+          in[i].dims().size(),
+          "All inputs of Concat should have same dims when add_axis = 1. "
+          "Got different sizes for inputs 0 and ",
+          i);
+      for (int j = 0; j < in[0].dims().size(); ++j) {
+        CAFFE_ENFORCE_EQ(
+            in[0].dims(j),
+            in[i].dims(j),
+            "All inputs of Concat should have same dims when add_axis = 1. "
+            "Got different dims for inputs 0 and ",
+            i,
+            ". At dim: ",
+            j);
+      }
+    }
+    out_shape.insert(out_shape.begin() + canonical_axis, in.size());
+  } else {
+    for (int i = 1; i < in.size(); ++i) {
+      CAFFE_ENFORCE_EQ(
+          in[0].dims().size(),
+          in[i].dims().size(),
+          "All inputs of Concat should have same dims except "
+          "canonical_axis dim that is equal to ",
+          canonical_axis,
+          "Got different sizes for inputs 0 and ",
+          i);
+      for (int j = 0; j < in[0].dims().size(); ++j) {
+        if (j == canonical_axis) {
+          continue;
+        }
+        CAFFE_ENFORCE_EQ(
+            in[0].dims(j),
+            in[i].dims(j),
+            "All inputs of Concat should have same dims except "
+            "canonical_axis dim that is equal to ",
+            canonical_axis,
+            "Got different dims for inputs 0 and ",
+            i,
+            ". At dim: ",
+            j);
+      }
+    }
+
+    for (int i = 1; i < in.size(); ++i) {
+      out_shape[canonical_axis] += in[i].dims(canonical_axis);
+    }
+  }
+  if (def.output_size() == 1) {
+    return vector<TensorShape>{CreateTensorShape(out_shape, in[0].data_type())};
+  }
+  return vector<TensorShape>{
+      CreateTensorShape(out_shape, in[0].data_type()),
+      CreateTensorShape(split_shape, TensorProto::INT32)};
+}
+
 REGISTER_CPU_OPERATOR(Concat, ConcatOp<CPUContext>);
 OPERATOR_SCHEMA(Concat)
     .NumInputs(1, INT_MAX)
@@ -168,83 +242,8 @@ OPERATOR_SCHEMA(Concat)
     .Arg(
         "add_axis",
         "*(type: int)* Pass non-zero integer to add the axis specified in `axis` to all input tensors.")
-    .TensorInferenceFunction(OpSchema::NeedsAllInputShapes([](const OperatorDef&
-                                                                  def,
-                                                              const vector<
-                                                                  TensorShape>&
-                                                                  in) {
-      ArgumentHelper helper(def);
-      const int axis = helper.HasArgument("axis")
-          ? helper.GetSingleArgument<int>("axis", -1)
-          : GetDimFromOrderString(
-                helper.GetSingleArgument<string>("order", "NCHW"));
-      bool add_axis = helper.GetSingleArgument<int>("add_axis", 0) != 0;
-      int adj_size = in[0].dims_size() + (add_axis ? 1 : 0);
-      const int canonical_axis = canonical_axis_index_(axis, adj_size);
-      CAFFE_ENFORCE_LT(
-          canonical_axis, adj_size, "Axis not in input ndim range.");
-      CAFFE_ENFORCE_GT(in.size(), 0);
-      vector<int> split_shape(1, in.size());
-      vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());
-      if (add_axis) {
-        for (int i = 1; i < in.size(); ++i) {
-          CAFFE_ENFORCE_EQ(
-              in[0].dims().size(),
-              in[i].dims().size(),
-              "All inputs of Concat should have same dims when add_axis = 1. "
-              "Got different sizes for inputs 0 and ",
-              i);
-          for (int j = 0; j < in[0].dims().size(); ++j) {
-            CAFFE_ENFORCE_EQ(
-                in[0].dims(j),
-                in[i].dims(j),
-                "All inputs of Concat should have same dims when add_axis = 1. "
-                "Got different dims for inputs 0 and ",
-                i,
-                ". At dim: ",
-                j);
-          }
-        }
-        out_shape.insert(out_shape.begin() + canonical_axis, in.size());
-      } else {
-        for (int i = 1; i < in.size(); ++i) {
-          CAFFE_ENFORCE_EQ(
-              in[0].dims().size(),
-              in[i].dims().size(),
-              "All inputs of Concat should have same dims except "
-              "canonical_axis dim that is equal to ",
-              canonical_axis,
-              "Got different sizes for inputs 0 and ",
-              i);
-          for (int j = 0; j < in[0].dims().size(); ++j) {
-            if (j == canonical_axis) {
-              continue;
-            }
-            CAFFE_ENFORCE_EQ(
-                in[0].dims(j),
-                in[i].dims(j),
-                "All inputs of Concat should have same dims except "
-                "canonical_axis dim that is equal to ",
-                canonical_axis,
-                "Got different dims for inputs 0 and ",
-                i,
-                ". At dim: ",
-                j);
-          }
-        }
-
-        for (int i = 1; i < in.size(); ++i) {
-          out_shape[canonical_axis] += in[i].dims(canonical_axis);
-        }
-      }
-      if (def.output_size() == 1) {
-        return vector<TensorShape>{
-            CreateTensorShape(out_shape, in[0].data_type())};
-      }
-      return vector<TensorShape>{
-          CreateTensorShape(out_shape, in[0].data_type()),
-          CreateTensorShape(split_shape, TensorProto::INT32)};
-    }))
+    .TensorInferenceFunction(
+        OpSchema::NeedsAllInputShapes(TensorInferenceForConcat))
     .CostInferenceFunction(CostInferenceForConcat)
     .DeviceInferenceFunction(concatOpDevInfer)
     .SetDoc(R"DOC(
index 53821ed..47ed663 100644 (file)
@@ -335,6 +335,14 @@ bool ConcatOp<Context>::RunOnDevice() {
   return true;
 }
 
+OpSchema::Cost CostInferenceForConcat(
+    const OperatorDef& def,
+    const std::vector<TensorShape>& in);
+
+std::vector<TensorShape> TensorInferenceForConcat(
+    const OperatorDef& def,
+    const std::vector<TensorShape>& in);
+
 } // namespace caffe2
 
 #endif // CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
index 2668f24..8371541 100644 (file)
@@ -2,17 +2,6 @@
 
 namespace caffe2 {
 
-namespace {
-OpSchema::Cost CostInferenceForSum(
-    const OperatorDef& def,
-    const vector<TensorShape>& in) {
-  struct OpSchema::Cost cost = PointwiseCostInference<1>(def, in);
-  cost.flops *= (in.size() - 1);
-  cost.params_bytes = 0;
-  return cost;
-}
-} // namespace
-
 REGISTER_CPU_OPERATOR(Sum, SumOp<CPUContext>);
 
 OPERATOR_SCHEMA(Sum)
index 3086b9a..a11867d 100644 (file)
@@ -7,27 +7,7 @@ REGISTER_CPU_OPERATOR(Flatten, FlattenOp<CPUContext>);
 OPERATOR_SCHEMA(Flatten)
     .NumInputs(1)
     .NumOutputs(1)
-    .TensorInferenceFunction([](const OperatorDef& def,
-                                const vector<TensorShape>& in) {
-      ArgumentHelper helper(def);
-      const int axis = helper.GetSingleArgument<int>("axis", 1);
-      vector<TensorShape> out(1);
-      int64_t outer = 1;
-      int64_t inner = 1;
-      std::size_t index = 0;
-      for (auto d : in[0].dims()) {
-        if (index < axis) {
-          outer *= d;
-        } else {
-          inner *= d;
-        }
-        ++index;
-      }
-      out[0].set_data_type(in[0].data_type());
-      out[0].add_dims(outer);
-      out[0].add_dims(inner);
-      return out;
-    })
+    .TensorInferenceFunction(TensorInferenceForFlatten)
     .SetDoc(R"DOC(
 Flattens the input tensor into a 2D matrix. If input tensor has shape
 $(d_0, d_1, ..., d_n)$ then the output will have shape
index cdc8d83..401e6fb 100644 (file)
@@ -33,6 +33,29 @@ class FlattenOp : public Operator<Context> {
   int axis_;
 };
 
+inline std::vector<TensorShape> TensorInferenceForFlatten(
+    const OperatorDef& def,
+    const std::vector<TensorShape>& in) {
+  ArgumentHelper helper(def);
+  const int axis = helper.GetSingleArgument<int>("axis", 1);
+  std::vector<TensorShape> out(1);
+  int64_t outer = 1;
+  int64_t inner = 1;
+  std::size_t index = 0;
+  for (auto d : in[0].dims()) {
+    if (index < axis) {
+      outer *= d;
+    } else {
+      inner *= d;
+    }
+    ++index;
+  }
+  out[0].set_data_type(in[0].data_type());
+  out[0].add_dims(outer);
+  out[0].add_dims(inner);
+  return out;
+}
+
 } // namespace caffe2
 
 #endif // CAFFE2_OPERATORS_FLATTEN_OP_H_
index 225a7a1..b1b7be2 100644 (file)
@@ -1,6 +1,8 @@
+#include "caffe2/operators/quantized/int8_add_op.h"
+
 #include <climits>
 
-#include "caffe2/operators/quantized/int8_add_op.h"
+#include "caffe2/operators/utility_ops.h"
 
 namespace caffe2 {
 
@@ -55,6 +57,8 @@ OPERATOR_SCHEMA(Int8Sum)
     .NumInputs(1, std::numeric_limits<int>::max())
     .NumOutputs(1)
     .AllowInplace({{0, 0}, {1, 0}})
+    .CostInferenceFunction(CostInferenceForSum)
+    .IdenticalTypeAndShapeOfInput(0)
     .Arg("Y_scale", "Output tensor quantization scale")
     .Arg("Y_zero_point", "Output tensor quantization offset");
 
@@ -62,6 +66,8 @@ OPERATOR_SCHEMA(Int8SumRelu)
     .NumInputs(1, std::numeric_limits<int>::max())
     .NumOutputs(1)
     .AllowInplace({{0, 0}, {1, 0}})
+    .CostInferenceFunction(CostInferenceForSum)
+    .IdenticalTypeAndShapeOfInput(0)
     .Arg("Y_scale", "Output tensor quantization scale")
     .Arg("Y_zero_point", "Output tensor quantization offset");
 
index 8950d41..ae601a3 100644 (file)
@@ -1,5 +1,7 @@
 #include "caffe2/operators/quantized/int8_concat_op.h"
 
+#include "caffe2/operators/concat_split_op.h"
+
 namespace caffe2 {
 
 REGISTER_CPU_OPERATOR(Int8Concat, int8::Int8ConcatOp);
@@ -14,6 +16,9 @@ OPERATOR_SCHEMA(Int8Concat)
         "add_axis",
         "Pass 1 to add the axis specified in arg 'axis' to all "
         "input tensors")
+    .TensorInferenceFunction(
+        OpSchema::NeedsAllInputShapes(TensorInferenceForConcat))
+    .CostInferenceFunction(CostInferenceForConcat)
     .SetDoc("Concatenate a list of tensors into a single tensor")
     .Output(0, "concat_result", "Concatenated tensor")
     .Output(1, "split_info", "The dimensions of the inputs.")
index ee7d605..0a8e504 100644 (file)
@@ -1,12 +1,19 @@
 #include "caffe2/operators/quantized/int8_fc_op.h"
 
+#include <functional>
+
+#include "caffe2/operators/fc_inference.h"
+
 namespace caffe2 {
 
 REGISTER_CPU_OPERATOR(Int8FC, int8::Int8FCOp);
 
+using namespace std::placeholders;
 OPERATOR_SCHEMA(Int8FC)
     .NumInputs(3)
     .NumOutputs(1)
+    .TensorInferenceFunction(std::bind(FCShapeInference, _1, _2, false))
+    .CostInferenceFunction(std::bind(CostInferenceForFC, _1, _2, false))
     .SetDoc(R"DOC(
 Computes the result of passing an input vector X into a fully
 connected layer with 2D weight matrix W and 1D bias vector b. That is,
index a9d5a5e..14e1381 100644 (file)
@@ -1,5 +1,7 @@
 #include "caffe2/operators/quantized/int8_flatten_op.h"
 
+#include "caffe2/operators/flatten_op.h"
+
 namespace caffe2 {
 
 REGISTER_CPU_OPERATOR(Int8Flatten, int8::Int8FlattenOp);
@@ -7,6 +9,7 @@ REGISTER_CPU_OPERATOR(Int8Flatten, int8::Int8FlattenOp);
 OPERATOR_SCHEMA(Int8Flatten)
     .NumInputs(1)
     .NumOutputs(1)
+    .TensorInferenceFunction(TensorInferenceForFlatten)
     .SetDoc(R"DOC(
 Flattens the input tensor into a 2D matrix. If input tensor has shape
 (d_0, d_1, ... d_n) then the output will have shape
index dc108d1..709c222 100644 (file)
@@ -12,7 +12,8 @@ OPERATOR_SCHEMA(Int8GivenTensorFill)
     .SetDoc(R"DOC(
     Creates quantized tensor of type char(byte) with scale and zero point info.
 )DOC")
-    .Output(0, "Tensor", "An Int8TensorCPU with scale and zero point info");
+    .Output(0, "Tensor", "An Int8TensorCPU with scale and zero point info")
+    .TensorInferenceFunction(FillerTensorInference<>);
 
 OPERATOR_SCHEMA(Int8GivenIntTensorFill)
     .NumInputs(0)
@@ -24,7 +25,8 @@ OPERATOR_SCHEMA(Int8GivenIntTensorFill)
     .SetDoc(R"DOC(
     Creates quantized tensor of type int32 with scale and zero point info.
 )DOC")
-    .Output(0, "Tensor", "An Int8TensorCPU with scale and zero point info");
+    .Output(0, "Tensor", "An Int8TensorCPU with scale and zero point info")
+    .TensorInferenceFunction(FillerTensorInference<>);
 
 REGISTER_CPU_OPERATOR(Int8GivenTensorFill, int8::Int8GivenTensorFillOp);
 REGISTER_CPU_OPERATOR(Int8GivenIntTensorFill, int8::Int8GivenIntTensorFillOp);
index a77ee9a..7b92b2c 100644 (file)
@@ -317,6 +317,15 @@ class SumOp : public Operator<Context> {
   }
 };
 
+inline OpSchema::Cost CostInferenceForSum(
+    const OperatorDef& def,
+    const std::vector<TensorShape>& in) {
+  struct OpSchema::Cost cost = PointwiseCostInference<1>(def, in);
+  cost.flops *= (in.size() - 1);
+  cost.params_bytes = 0;
+  return cost;
+}
+
 // WeightedSumOp computes the weighted sum of several tensors. The input should
 // be in the form X_0, weight_0, X_1, weight_1, ... where X_i all have the same
 // shape, and weight_i are size 1 tensors that specifies the weight of each