// Register layer norm with c10
namespace {
// Kernel wrapper that runs the Caffe2 CPU LayerNorm implementation for
// element type `DataType`.
//
// NOTE(review): this span is a diff hunk. Lines starting with '-' are the old
// signature, which took each argument as a separate typed parameter; lines
// starting with '+' are the replacement, which takes all arguments packed in
// one c10::ArrayRef<c10::IValue>.
//
// Argument layout read by the '+' version:
//   inputs[0] -> X     (tensor)
//   inputs[1] -> Y     (tensor, passed by pointer to runLayerNorm)
//   inputs[2] -> mean  (tensor, passed by pointer to runLayerNorm)
//   inputs[3] -> sig   (tensor, passed by pointer to runLayerNorm)
//   inputs[4] -> axis  (read with toInt())
//   inputs[5] -> epsilon (read with toDouble(), stored in a float)
//   inputs[6] -> blob from which a LayerNorm::Cache is obtained via GetMutable
//
// The '+' version returns a default-constructed c10::IValue.
//
// NOTE(review): no size check on `inputs` is visible before indices 0..6 are
// accessed — presumably the caller guarantees seven arguments; confirm.
template <class DataType>
-void layer_norm_c10(
- const at::Tensor& X_,
- const at::Tensor& Y_,
- const at::Tensor& mean_,
- const at::Tensor& sig_,
- int axis,
- float epsilon,
- c10::intrusive_ptr<caffe2::Blob> cache_) {
- caffe2::Tensor X{c10::C10Tensor(X_)};
- caffe2::Tensor Y{c10::C10Tensor(Y_)};
- caffe2::Tensor mean{c10::C10Tensor(mean_)};
- caffe2::Tensor sig{c10::C10Tensor(sig_)};
+c10::IValue layer_norm_c10(c10::ArrayRef<c10::IValue> inputs) {
+ caffe2::Tensor X{c10::C10Tensor(inputs[0].toTensor())};
+ caffe2::Tensor Y{c10::C10Tensor(inputs[1].toTensor())};
+ caffe2::Tensor mean{c10::C10Tensor(inputs[2].toTensor())};
+ caffe2::Tensor sig{c10::C10Tensor(inputs[3].toTensor())};
+ int64_t axis = inputs[4].toInt();
+ float epsilon = inputs[5].toDouble();
caffe2::CPUContext context;
// The cache object is fetched from the blob argument (old: `cache_`,
// new: inputs[6]) as a mutable LayerNorm::Cache.
- c10::core::opschema::LayerNorm::Cache* cache = cache_->GetMutable<c10::core::opschema::LayerNorm::Cache>();
+ c10::core::opschema::LayerNorm::Cache* cache = inputs[6].toBlob()->GetMutable<c10::core::opschema::LayerNorm::Cache>();
// If the cache has no `scale` yet, set it to a tensor built from
// caffe2::Tensor{caffe2::CPU}.
if (!cache->scale.has_value()) {
cache->scale = at::Tensor(c10::C10Tensor(caffe2::Tensor{caffe2::CPU}));
}
// NOTE(review): `canonical_axis`, `scale` and `bias` are used below but are
// not declared in the lines visible here — presumably they are defined in
// lines elided from this hunk; confirm against the full file.
caffe2::LayerNormOp<caffe2::CPUContext>::runLayerNorm<DataType>(
X, &Y, &mean, &sig, canonical_axis, epsilon, &scale, &bias, static_cast<caffe2::CPUContext*>(&context)
);
// The new signature returns an IValue; an empty one is returned here.
+ return c10::IValue();
}
}
namespace c10 {