Pass IValue from c10 dispatcher to caffe2 operator (#16065)
authorSebastian Messmer <messmer@fb.com>
Fri, 18 Jan 2019 23:55:57 +0000 (15:55 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 19 Jan 2019 00:02:18 +0000 (16:02 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16065

Before, we registered the caffe2 kernel with the c10 dispatcher using plain C types.
Now, we pass in IValues, which avoids the unwrapping inbetween.

Reviewed By: ezyang

Differential Revision: D13689036

fbshipit-source-id: b976a2c46a5a541f6a926b3df255e8a535e32420

caffe2/operators/layer_norm_op.cc

index ca39be5..b4180a6 100644 (file)
@@ -187,20 +187,15 @@ to the end.)
 // Register layer norm with c10
 namespace {
 template <class DataType>
-void layer_norm_c10(
-    const at::Tensor& X_,
-    const at::Tensor& Y_,
-    const at::Tensor& mean_,
-    const at::Tensor& sig_,
-    int axis,
-    float epsilon,
-    c10::intrusive_ptr<caffe2::Blob> cache_) {
-  caffe2::Tensor X{c10::C10Tensor(X_)};
-  caffe2::Tensor Y{c10::C10Tensor(Y_)};
-  caffe2::Tensor mean{c10::C10Tensor(mean_)};
-  caffe2::Tensor sig{c10::C10Tensor(sig_)};
+c10::IValue layer_norm_c10(c10::ArrayRef<c10::IValue> inputs) {
+  caffe2::Tensor X{c10::C10Tensor(inputs[0].toTensor())};
+  caffe2::Tensor Y{c10::C10Tensor(inputs[1].toTensor())};
+  caffe2::Tensor mean{c10::C10Tensor(inputs[2].toTensor())};
+  caffe2::Tensor sig{c10::C10Tensor(inputs[3].toTensor())};
+  int64_t axis = inputs[4].toInt();
+  float epsilon = inputs[5].toDouble();
   caffe2::CPUContext context;
-  c10::core::opschema::LayerNorm::Cache* cache = cache_->GetMutable<c10::core::opschema::LayerNorm::Cache>();
+  c10::core::opschema::LayerNorm::Cache* cache = inputs[6].toBlob()->GetMutable<c10::core::opschema::LayerNorm::Cache>();
   if (!cache->scale.has_value()) {
     cache->scale = at::Tensor(c10::C10Tensor(caffe2::Tensor{caffe2::CPU}));
   }
@@ -219,6 +214,7 @@ void layer_norm_c10(
   caffe2::LayerNormOp<caffe2::CPUContext>::runLayerNorm<DataType>(
     X, &Y, &mean, &sig, canonical_axis, epsilon, &scale, &bias, static_cast<caffe2::CPUContext*>(&context)
   );
+  return c10::IValue();
 }
 }
 namespace c10 {