Remove Context from c10 operator schemas (#15312)
authorSebastian Messmer <messmer@fb.com>
Fri, 11 Jan 2019 00:06:26 +0000 (16:06 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 00:22:20 +0000 (16:22 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15312

Context will soon be entirely obsolete. Remove it from the operator schema interface.

Reviewed By: dzhulgakov

Differential Revision: D13495323

fbshipit-source-id: caa0f8f092cd6284e510c3e1e3374fe2f8338364

25 files changed:
c10/core/opschema/layer_norm.h
caffe2/core/operator_c10wrapper.h
caffe2/operators/experimental/c10/cpu/add_cpu.cc
caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc
caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc
caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc
caffe2/operators/experimental/c10/cpu/concat_cpu.cc
caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc
caffe2/operators/experimental/c10/cpu/fc_cpu.cc
caffe2/operators/experimental/c10/cpu/filler_cpu.cc
caffe2/operators/experimental/c10/cpu/flatten_cpu.cc
caffe2/operators/experimental/c10/cpu/mul_cpu.cc
caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc
caffe2/operators/experimental/c10/schemas/add.h
caffe2/operators/experimental/c10/schemas/averaged_loss.h
caffe2/operators/experimental/c10/schemas/batch_gather.h
caffe2/operators/experimental/c10/schemas/batch_matmul.h
caffe2/operators/experimental/c10/schemas/concat.h
caffe2/operators/experimental/c10/schemas/expand_dims.h
caffe2/operators/experimental/c10/schemas/fc.h
caffe2/operators/experimental/c10/schemas/filler.h
caffe2/operators/experimental/c10/schemas/flatten.h
caffe2/operators/experimental/c10/schemas/mul.h
caffe2/operators/experimental/c10/schemas/stop_gradient.h
torch/csrc/jit/c10_ops/layer_norm.cpp

index a00ded0..d80c965 100644 (file)
@@ -3,10 +3,6 @@
 #include <c10/core/Tensor.h>
 #include <c10/util/Array.h>
 
-namespace at {
-class BaseContext;
-}
-
 namespace c10 {
 namespace core {
 namespace opschema {
index 88f3536..35df591 100644 (file)
@@ -55,10 +55,6 @@ class C10OperatorWrapper final : public Operator<Context> {
 
   USE_OPERATOR_CONTEXT_FUNCTIONS;
 
-  static constexpr bool op_has_context_argument = std::is_same<
-      BaseContext*,
-      c10::guts::typelist::last_t<
-          typename Schema::signature::parameter_types>>::value;
   static constexpr bool op_has_state_argument =
       !std::is_same<void, State>::value;
 
@@ -71,7 +67,7 @@ class C10OperatorWrapper final : public Operator<Context> {
 
   static constexpr size_t num_inputs() {
     return Schema::signature::num_args - num_outputs() - num_parameters() -
-        (op_has_context_argument ? 1 : 0) - (op_has_state_argument ? 1 : 0);
+        (op_has_state_argument ? 1 : 0);
   }
 
   static constexpr size_t num_parameters() {
@@ -112,46 +108,7 @@ class C10OperatorWrapper final : public Operator<Context> {
       size_t... OutputIndex,
       size_t... ParameterIndex>
   c10::guts::enable_if_t<
-      details::true_t<InputIndex...>::value && op_has_context_argument &&
-          op_has_state_argument && !use_array_input,
-      void>
-  RunOnDevice_(
-      c10::guts::index_sequence<InputIndex...>,
-      c10::guts::index_sequence<OutputIndex...>,
-      c10::guts::index_sequence<ParameterIndex...>) {
-    c10::Dispatcher<OpSchemaDef>::call(
-        C10Tensor(Input(InputIndex))...,
-        C10Tensor(*Output(OutputIndex))...,
-        std::get<ParameterIndex>(parameters_)...,
-        state_.get(),
-        static_cast<BaseContext*>(&context_));
-  }
-
-  template <
-      size_t... InputIndex,
-      size_t... OutputIndex,
-      size_t... ParameterIndex>
-  c10::guts::enable_if_t<
-      details::true_t<InputIndex...>::value && op_has_context_argument &&
-          !op_has_state_argument && !use_array_input,
-      void>
-  RunOnDevice_(
-      c10::guts::index_sequence<InputIndex...>,
-      c10::guts::index_sequence<OutputIndex...>,
-      c10::guts::index_sequence<ParameterIndex...>) {
-    c10::Dispatcher<OpSchemaDef>::call(
-        C10Tensor(Input(InputIndex))...,
-        C10Tensor(*Output(OutputIndex))...,
-        std::get<ParameterIndex>(parameters_)...,
-        static_cast<BaseContext*>(&context_));
-  }
-
-  template <
-      size_t... InputIndex,
-      size_t... OutputIndex,
-      size_t... ParameterIndex>
-  c10::guts::enable_if_t<
-      details::true_t<InputIndex...>::value && !op_has_context_argument &&
+      details::true_t<InputIndex...>::value &&
           op_has_state_argument && !use_array_input,
       void>
   RunOnDevice_(
@@ -170,7 +127,7 @@ class C10OperatorWrapper final : public Operator<Context> {
       size_t... OutputIndex,
       size_t... ParameterIndex>
   c10::guts::enable_if_t<
-      details::true_t<InputIndex...>::value && !op_has_context_argument &&
+      details::true_t<InputIndex...>::value &&
           !op_has_state_argument && !use_array_input,
       void>
   RunOnDevice_(
@@ -188,46 +145,7 @@ class C10OperatorWrapper final : public Operator<Context> {
       size_t... OutputIndex,
       size_t... ParameterIndex>
   c10::guts::enable_if_t<
-      details::true_t<InputIndex...>::value && op_has_context_argument &&
-          op_has_state_argument && use_array_input,
-      void>
-  RunOnDevice_(
-      c10::guts::index_sequence<InputIndex...>,
-      c10::guts::index_sequence<OutputIndex...>,
-      c10::guts::index_sequence<ParameterIndex...>) {
-    c10::Dispatcher<OpSchemaDef>::call(
-        at::ArrayRef<C10Tensor>(array_inputs_()),
-        C10Tensor(*Output(OutputIndex))...,
-        std::get<ParameterIndex>(parameters_)...,
-        state_.get(),
-        static_cast<BaseContext*>(&context_));
-  }
-
-  template <
-      size_t... InputIndex,
-      size_t... OutputIndex,
-      size_t... ParameterIndex>
-  c10::guts::enable_if_t<
-      details::true_t<InputIndex...>::value && op_has_context_argument &&
-          !op_has_state_argument && use_array_input,
-      void>
-  RunOnDevice_(
-      c10::guts::index_sequence<InputIndex...>,
-      c10::guts::index_sequence<OutputIndex...>,
-      c10::guts::index_sequence<ParameterIndex...>) {
-    c10::Dispatcher<OpSchemaDef>::call(
-        at::ArrayRef<C10Tensor>(array_inputs_()),
-        C10Tensor(*Output(OutputIndex))...,
-        std::get<ParameterIndex>(parameters_)...,
-        static_cast<BaseContext*>(&context_));
-  }
-
-  template <
-      size_t... InputIndex,
-      size_t... OutputIndex,
-      size_t... ParameterIndex>
-  c10::guts::enable_if_t<
-      details::true_t<InputIndex...>::value && !op_has_context_argument &&
+      details::true_t<InputIndex...>::value &&
           op_has_state_argument && use_array_input,
       void>
   RunOnDevice_(
@@ -246,7 +164,7 @@ class C10OperatorWrapper final : public Operator<Context> {
       size_t... OutputIndex,
       size_t... ParameterIndex>
   c10::guts::enable_if_t<
-      details::true_t<InputIndex...>::value && !op_has_context_argument &&
+      details::true_t<InputIndex...>::value &&
           !op_has_state_argument && use_array_input,
       void>
   RunOnDevice_(
index b55357b..34d4a51 100644 (file)
@@ -15,11 +15,11 @@ void add_op_cpu_impl(
     const C10Tensor& B_,
     const C10Tensor& C_,
     bool legacy_broadcast,
-    int axis,
-    BaseContext* context) {
+    int axis) {
   Tensor A(A_);
   Tensor B(B_);
   Tensor C(C_);
+  CPUContext context;
   const DataType* A_data = A.template data<DataType>();
   const DataType* B_data = B.template data<DataType>();
   std::vector<int> A_dims;
@@ -68,7 +68,7 @@ void add_op_cpu_impl(
       A.data<DataType>(),
       B.data<DataType>(),
       C.mutable_data<DataType>(),
-      static_cast<CPUContext*>(context));
+      static_cast<CPUContext*>(&context));
 }
 } // namespace
 } // namespace caffe2
index 276eb75..6e4f2e9 100644 (file)
@@ -14,10 +14,10 @@ template <class T, class Context>
 void averaged_loss_op_cpu_impl(
     const C10Tensor& X_,
     const C10Tensor& sum_,
-    caffe2::ops::AveragedLoss::State* state,
-    BaseContext* context) {
+    caffe2::ops::AveragedLoss::State* state) {
   Tensor X(X_);
   Tensor sum(sum_);
+  CPUContext context;
 
   sum.Resize(vector<int64_t>());
 
@@ -28,7 +28,7 @@ void averaged_loss_op_cpu_impl(
       X.numel(),
       X.template data<T>(),
       data,
-      static_cast<Context*>(context),
+      static_cast<Context*>(&context),
       &scratch);
   if (X.numel() > 0) {
     caffe2::math::Scale<T, T, Context>(
@@ -36,7 +36,7 @@ void averaged_loss_op_cpu_impl(
         static_cast<T>(1.) / X.numel(),
         sum.template data<T>(),
         data,
-        static_cast<Context*>(context));
+        static_cast<Context*>(&context));
   }
 }
 } // namespace
index 01caa33..483a810 100644 (file)
@@ -14,11 +14,11 @@ template <class TInd>
 void batch_gather_op_cpu_impl(
     const C10Tensor& data_,
     const C10Tensor& indices_,
-    const C10Tensor& output_,
-    BaseContext* context) {
+    const C10Tensor& output_) {
   Tensor data(data_);
   Tensor indices(indices_);
   Tensor output(output_);
+  CPUContext context;
 
   CAFFE_ENFORCE_GE(data.dim(), 2, "DATA should be at least 2-D");
 
@@ -49,7 +49,7 @@ void batch_gather_op_cpu_impl(
           data.size(1));
       auto src = src_base + idx * block_bytesize + batch * data_batch_bytesize;
       auto dst = out + i * block_bytesize + batch * gathered_batch_bytesize;
-      context->CopyItemsSameDevice(data.dtype(), block_size, src, dst);
+      context.CopyItemsSameDevice(data.dtype(), block_size, src, dst);
     }
   }
 }
index 476cf1e..b3b93bb 100644 (file)
@@ -19,11 +19,11 @@ void batch_matmul_op_cpu_impl(
     int trans_a,
     int trans_b,
     int broadcast,
-    caffe2::ops::BatchMatmul::State* state,
-    BaseContext* context) {
+    caffe2::ops::BatchMatmul::State* state) {
   Tensor A(A_);
   Tensor B(B_);
   Tensor Y(Y_);
+  CPUContext context;
   using Engine = caffe2::DefaultEngine;
 
   auto ndims_A = A.dim();
@@ -83,7 +83,7 @@ void batch_matmul_op_cpu_impl(
         "be the same size.");
     Y.Resize(1);
     math::Dot<T, Context>(
-        dims_A[0], data_A, data_B, Y.template mutable_data<T>(), static_cast<Context*>(context));
+        dims_A[0], data_A, data_B, Y.template mutable_data<T>(), static_cast<Context*>(&context));
   } else {
     bool A_broadcasted = false, B_broadcasted = false;
     if (ndims_A == 1) {
@@ -260,7 +260,7 @@ void batch_matmul_op_cpu_impl(
           0.0f,
           Y_data + p * Y_stride,
           M * N,
-          static_cast<Context*>(context));
+          static_cast<Context*>(&context));
     }
   }
 }
index 9179cc4..1743bd2 100644 (file)
@@ -17,10 +17,10 @@ void concat_op_cpu_impl(
     const C10Tensor& output_,
     const C10Tensor& split_,
     int axis,
-    int add_axis,
-    BaseContext* context) {
+    int add_axis) {
   Tensor output(output_);
   Tensor split(split_);
+  CPUContext context;
 
   split.Resize(vector<int64_t>(1, inputs.size()));
   int* axis_data = split.template mutable_data<int>();
@@ -98,7 +98,7 @@ void concat_op_cpu_impl(
         static_cast<char*>(output.raw_mutable_data(Tensor(inputs[0]).dtype())) +
             output_offset,
         output_channels * after,
-        static_cast<Context*>(context),
+        static_cast<Context*>(&context),
         Tensor(inputs[0]).dtype().copy());
     output_offset += axis_dim * after * input.itemsize();
   }
index 5350ac7..8a6bd58 100644 (file)
@@ -3,7 +3,6 @@
 #include "caffe2/utils/math.h"
 #include "caffe2/core/tensor.h"
 
-using caffe2::BaseContext;
 using caffe2::Tensor;
 
 namespace caffe2 {
@@ -13,8 +12,7 @@ void expand_dims_op_cpu_impl(
     const C10Tensor& input_,
     const C10Tensor& output_,
     const std::vector<int>& dims,
-    caffe2::ops::ExpandDims::State* state,
-    BaseContext* context) {
+    caffe2::ops::ExpandDims::State* state) {
   Tensor input(input_);
   Tensor output(output_);
 
@@ -33,7 +31,7 @@ void expand_dims_op_cpu_impl(
     state->initialized = true;
   }
 
-  output.CopyFrom(input, context);
+  output.CopyFrom(input);
   if (state->dims.empty()) {
     return;
   }
index 3deb2a9..4498f38 100644 (file)
@@ -19,12 +19,12 @@ void fc_op_cpu_impl(
     const C10Tensor& Y_,
     int axis,
     int axis_w,
-    caffe2::ops::FullyConnected::Cache* cache,
-    BaseContext* context) {
+    caffe2::ops::FullyConnected::Cache* cache) {
   Tensor X(X_);
   Tensor W(W_);
   Tensor b(b_);
   Tensor Y(Y_);
+  CPUContext context;
 
   constexpr bool TransposeWeight = true;
 
@@ -94,7 +94,7 @@ void fc_op_cpu_impl(
       W.template data<DataType>(),
       0,
       Y.template mutable_data<DataType>(),
-      static_cast<Context*>(context),
+      static_cast<Context*>(&context),
       math_type);
   // Add bias term
   Tensor bias_multiplier(cache->bias_multiplier_);
@@ -105,7 +105,7 @@ void fc_op_cpu_impl(
         M,
         caffe2::convert::To<float, DataType>(1),
         bias_multiplier.template mutable_data<DataType>(),
-        static_cast<Context*>(context));
+        static_cast<Context*>(&context));
   }
   caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
       CblasNoTrans,
@@ -118,7 +118,7 @@ void fc_op_cpu_impl(
       b.template data<DataType>(),
       1,
       Y.template mutable_data<DataType>(),
-      static_cast<Context*>(context),
+      static_cast<Context*>(&context),
       math_type);
 }
 } // namespace
index 6e2c223..77277ce 100644 (file)
@@ -49,10 +49,10 @@ void given_tensor_fill_op_cpu_impl(
     const std::vector<int64_t>& shape,
     const std::vector<int>& extra_shape,
     bool input_as_shape,
-    const C10Tensor& values_,
-    BaseContext* context) {
+    const C10Tensor& values_) {
   Tensor output(output_);
   Tensor values(values_);
+  CPUContext context;
 
   filler_init(inputs, output_, shape, extra_shape, input_as_shape);
 
@@ -64,7 +64,7 @@ void given_tensor_fill_op_cpu_impl(
   auto* data = output.template mutable_data<Type>();
   const Type* values_data = values.template data<Type>();
   if (output.numel()) {
-    context->CopySameDevice(output.numel(), values_data, data);
+    context.CopySameDevice(output.numel(), values_data, data);
   }
 }
 
@@ -75,9 +75,9 @@ void constant_fill_op_cpu_impl(
     const std::vector<int>& extra_shape,
     bool input_as_shape,
     int dtype,
-    caffe2::ops::ConstantFill::Value value,
-    BaseContext* context) {
+    caffe2::ops::ConstantFill::Value value) {
   Tensor output(output_);
+  CPUContext context;
 
   filler_init(inputs, output_, shape, extra_shape, input_as_shape);
 
@@ -87,25 +87,25 @@ void constant_fill_op_cpu_impl(
           output.numel(),
           value.as_float,
           output.template mutable_data<float>(),
-          static_cast<CPUContext*>(context));
+          static_cast<CPUContext*>(&context));
     } else if (dtype == caffe2::TensorProto_DataType_INT32) {
       caffe2::math::Set<int32_t, CPUContext>(
           output.numel(),
           value.as_int32,
           output.template mutable_data<int32_t>(),
-          static_cast<CPUContext*>(context));
+          static_cast<CPUContext*>(&context));
     } else if (dtype == caffe2::TensorProto_DataType_INT64) {
       caffe2::math::Set<int64_t, CPUContext>(
           output.numel(),
           value.as_int64,
           output.template mutable_data<int64_t>(),
-          static_cast<CPUContext*>(context));
+          static_cast<CPUContext*>(&context));
     } else if (dtype == caffe2::TensorProto_DataType_BOOL) {
       caffe2::math::Set<bool, CPUContext>(
           output.numel(),
           value.as_bool,
           output.template mutable_data<bool>(),
-          static_cast<CPUContext*>(context));
+          static_cast<CPUContext*>(&context));
     } else {
       throw std::logic_error(
           "Unimplemented data type for ConstantFill: " +
@@ -121,9 +121,9 @@ void uniform_fill_op_cpu_impl(
     const std::vector<int>& extra_shape,
     bool input_as_shape,
     float min,
-    float max,
-    BaseContext* context) {
+    float max) {
   Tensor output(output_);
+  CPUContext context;
 
   filler_init(inputs, output_, shape, extra_shape, input_as_shape);
 
@@ -145,7 +145,7 @@ void uniform_fill_op_cpu_impl(
       min,
       max,
       output.template mutable_data<float>(),
-      static_cast<CPUContext*>(context));
+      static_cast<CPUContext*>(&context));
 }
 } // namespace
 } // namespace caffe2
index a1b1ae0..347ed88 100644 (file)
@@ -12,14 +12,14 @@ template <class DataType, class Context>
 void flatten_op_cpu_impl(
     const C10Tensor& input_,
     const C10Tensor& output_,
-    int axis,
-    BaseContext* context) {
+    int axis) {
   Tensor input(input_);
   Tensor output(output_);
+  CPUContext context;
   CAFFE_ENFORCE_GE(
       input.sizes().size(), axis, "The rank of the tensor must be >= axis.");
   output.Resize(input.size_to_dim(axis), input.size_from_dim(axis));
-  context->CopyItemsSameDevice(
+  context.CopyItemsSameDevice(
       input.dtype(),
       input.numel(),
       input.raw_data(),
index 1633449..568bf91 100644 (file)
@@ -16,11 +16,11 @@ void mul_op_cpu_impl(
     const C10Tensor& B_,
     const C10Tensor& C_,
     bool legacy_broadcast,
-    int axis,
-    BaseContext* context) {
+    int axis) {
   Tensor A(A_);
   Tensor B(B_);
   Tensor C(C_);
+  CPUContext context;
   const DataType* A_data = A.template data<DataType>();
   const DataType* B_data = B.template data<DataType>();
   std::vector<int> A_dims;
@@ -69,7 +69,7 @@ void mul_op_cpu_impl(
       A.data<DataType>(),
       B.data<DataType>(),
       C.mutable_data<DataType>(),
-      static_cast<CPUContext*>(context));
+      static_cast<CPUContext*>(&context));
 }
 } // namespace
 } // namespace caffe2
index 3f0fa0b..e4e4415 100644 (file)
@@ -11,12 +11,11 @@ namespace {
 template <class DataType>
 void stop_gradient_op_cpu_impl(
     const C10Tensor& input_,
-    const C10Tensor& output_,
-    BaseContext* context) {
+    const C10Tensor& output_) {
   Tensor input(input_);
   Tensor output(output_);
   if (output.getIntrusivePtr() != input.getIntrusivePtr()) {
-    output.CopyFrom(input, context);
+    output.CopyFrom(input);
   }
 }
 } // namespace
index b1334e0..75c4a97 100644 (file)
@@ -15,15 +15,14 @@ struct Add final {
       const C10Tensor& input2,
       const C10Tensor& output,
       bool legacy_broadcast,
-      int axis,
-      BaseContext* context);
+      int axis);
 
   static constexpr size_t num_dispatch_args() {return 2;}
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 6> parameter_names = {
-      {"input1", "input2", "output", "legacy_broadcast", "axis", "context"}};
+  static constexpr c10::guts::array<const char*, 5> parameter_names = {
+      {"input1", "input2", "output", "legacy_broadcast", "axis"}};
 };
 
 } // namespace ops
index 1f87651..8000181 100644 (file)
@@ -18,15 +18,14 @@ struct AveragedLoss final {
   using Signature = void(
       const C10Tensor& input,
       const C10Tensor& output,
-      State* state,
-      BaseContext* context);
+      State* state);
 
   static constexpr size_t num_dispatch_args() {return 1;}
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 4> parameter_names = {
-      {"input", "output", "state", "context"}};
+  static constexpr c10::guts::array<const char*, 3> parameter_names = {
+      {"input", "output", "state"}};
 };
 
 } // namespace ops
index e51c0a4..fc4f5cc 100644 (file)
@@ -13,15 +13,14 @@ struct BatchGather final {
   using Signature = void(
       const C10Tensor& data,
       const C10Tensor& indices,
-      const C10Tensor& output,
-      BaseContext* context);
+      const C10Tensor& output);
 
   static constexpr size_t num_dispatch_args() {return 2;}
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 4> parameter_names = {
-      {"data", "indices", "output", "context"}};
+  static constexpr c10::guts::array<const char*, 3> parameter_names = {
+      {"data", "indices", "output"}};
 };
 
 } // namespace ops
index 6788f4a..90d8b16 100644 (file)
@@ -21,22 +21,20 @@ struct BatchMatmul final {
       int trans_a,
       int trans_b,
       int broadcast,
-      State* state,
-      BaseContext* context);
+      State* state);
 
   static constexpr size_t num_dispatch_args() {return 2;}
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 8> parameter_names = {
+  static constexpr c10::guts::array<const char*, 7> parameter_names = {
       {"A",
        "B",
        "output",
        "trans_a",
        "trans_b",
        "broadcast",
-       "state",
-       "context"}};
+       "state"}};
 };
 
 } // namespace ops
index e739658..e41f62a 100644 (file)
@@ -17,21 +17,19 @@ struct Concat final {
       const C10Tensor& output,
       const C10Tensor& split_info,
       int add,
-      int add_axis,
-      BaseContext* context);
+      int add_axis);
 
   static constexpr size_t num_outputs() {return 2;}
 
-  static constexpr c10::guts::array<const char*, 6> parameter_names = {
-      {"inputs", "output", "split_info_output", "add", "add_axis", "context"}};
+  static constexpr c10::guts::array<const char*, 5> parameter_names = {
+      {"inputs", "output", "split_info_output", "add", "add_axis"}};
 
   static c10::DeviceTypeId dispatch_key(
       at::ArrayRef<C10Tensor> inputs,
       const C10Tensor& output,
       const C10Tensor& split_info,
       int add,
-      int add_axis,
-      BaseContext* context) {
+      int add_axis) {
     return c10::DeviceTypeId::CPU;
   }
 };
index 6684652..a472189 100644 (file)
@@ -19,15 +19,14 @@ struct ExpandDims final {
       const C10Tensor& input,
       const C10Tensor& output,
       const std::vector<int>& dims,
-      State* state,
-      BaseContext* context);
+      State* state);
 
   static constexpr size_t num_dispatch_args() {return 1;}
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 5> parameter_names = {
-      {"input", "output", "dims", "state", "context"}};
+  static constexpr c10::guts::array<const char*, 4> parameter_names = {
+      {"input", "output", "dims", "state"}};
 };
 
 } // namespace ops
index bea1353..9e6475b 100644 (file)
@@ -22,15 +22,14 @@ struct FullyConnected final {
       const C10Tensor& output,
       int axis,
       int axis_w,
-      Cache* cache,
-      BaseContext* context);
+      Cache* cache);
 
   static constexpr size_t num_dispatch_args() {return 3;}
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 8> parameter_names = {
-      {"X", "W", "b", "output", "axis", "axis_w", "cache", "context"}};
+  static constexpr c10::guts::array<const char*, 7> parameter_names = {
+      {"X", "W", "b", "output", "axis", "axis_w", "cache"}};
 };
 
 } // namespace ops
index dc81ca7..fa4ef51 100644 (file)
@@ -23,17 +23,15 @@ struct GivenTensorFill final {
       const std::vector<int64_t>& shape,
       const std::vector<int>& extra_shape,
       bool input_as_shape,
-      const C10Tensor& values,
-      BaseContext* context);
+      const C10Tensor& values);
 
-  static constexpr c10::guts::array<const char*, 7> parameter_names = {
+  static constexpr c10::guts::array<const char*, 6> parameter_names = {
       {"inputs",
        "output",
        "shape",
        "extra_shape",
        "input_as_shape",
-       "values",
-       "context"}};
+       "values"}};
 
    static constexpr size_t num_outputs() {return 1;}
 
@@ -43,8 +41,7 @@ struct GivenTensorFill final {
       const std::vector<int64_t>& shape,
       const std::vector<int>& extra_shape,
       bool input_as_shape,
-      const C10Tensor& values,
-      BaseContext* context) {
+      const C10Tensor& values) {
     return c10::DeviceTypeId::CPU;
   }
 };
@@ -65,20 +62,18 @@ struct ConstantFill final {
       const std::vector<int>& extra_shape,
       bool input_as_shape,
       int dtype,
-      Value value,
-      BaseContext* context);
+      Value value);
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 8> parameter_names = {
+  static constexpr c10::guts::array<const char*, 7> parameter_names = {
       {"inputs",
        "output",
        "shape",
        "extra_shape",
        "input_as_shape",
        "dtype",
-       "value",
-       "context"}};
+       "value"}};
 
   static c10::DeviceTypeId dispatch_key(
       at::ArrayRef<C10Tensor> inputs,
@@ -87,8 +82,7 @@ struct ConstantFill final {
       const std::vector<int>& extra_shape,
       bool input_as_shape,
       int dtype,
-      Value value,
-      BaseContext* context) {
+      Value value) {
     return c10::DeviceTypeId::CPU;
   }
 };
@@ -103,20 +97,18 @@ struct UniformFill final {
       const std::vector<int>& extra_shape,
       bool input_as_shape,
       float min,
-      float max,
-      BaseContext* context);
+      float max);
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 8> parameter_names = {
+  static constexpr c10::guts::array<const char*, 7> parameter_names = {
       {"inputs",
        "output",
        "shape",
        "extra_shape",
        "input_as_shape",
        "min",
-       "max",
-       "context"}};
+       "max"}};
 
   static c10::DeviceTypeId dispatch_key(
       at::ArrayRef<C10Tensor> inputs,
@@ -125,8 +117,7 @@ struct UniformFill final {
       const std::vector<int>& extra_shape,
       bool input_as_shape,
       float min,
-      float max,
-      BaseContext* context) {
+      float max) {
     return c10::DeviceTypeId::CPU;
   }
 };
index 26954d9..31622d6 100644 (file)
@@ -13,15 +13,14 @@ struct Flatten final {
   using Signature = void(
       const C10Tensor& input,
       const C10Tensor& output,
-      int axis,
-      BaseContext* context);
+      int axis);
 
   static constexpr size_t num_dispatch_args() {return 1;}
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 4> parameter_names = {
-      {"input", "output", "axis", "context"}};
+  static constexpr c10::guts::array<const char*, 3> parameter_names = {
+      {"input", "output", "axis"}};
 };
 
 } // namespace ops
index 92187dd..6d7bdff 100644 (file)
@@ -15,15 +15,14 @@ struct Mul final {
       const C10Tensor& input2,
       const C10Tensor& output,
       bool legacy_broadcast,
-      int axis,
-      BaseContext* context);
+      int axis);
 
   static constexpr size_t num_dispatch_args() {return 2;}
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 6> parameter_names = {
-      {"input1", "input2", "output", "legacy_broadcast", "axis", "context"}};
+  static constexpr c10::guts::array<const char*, 5> parameter_names = {
+      {"input1", "input2", "output", "legacy_broadcast", "axis"}};
 };
 
 } // namespace ops
index f38a4aa..7c17765 100644 (file)
@@ -12,15 +12,14 @@ struct StopGradient final {
 
   using Signature = void(
       const C10Tensor& input,
-      const C10Tensor& output,
-      BaseContext* context);
+      const C10Tensor& output);
 
   static constexpr size_t num_dispatch_args() {return 1;}
 
   static constexpr size_t num_outputs() {return 1;}
 
-  static constexpr c10::guts::array<const char*, 3> parameter_names = {
-      {"input", "output", "context"}};
+  static constexpr c10::guts::array<const char*, 2> parameter_names = {
+      {"input", "output"}};
 };
 
 } // namespace ops
index d0d874b..3d93b45 100644 (file)
@@ -2,7 +2,6 @@
 #include <c10/core/opschema/layer_norm.h>
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/autograd/variable.h>
-#include <caffe2/core/context.h>
 
 using c10::C10Tensor;