add ConvRelu schema (#18693)
authorJongsoo Park <jongsoo@fb.com>
Mon, 1 Apr 2019 20:02:02 +0000 (13:02 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 1 Apr 2019 20:09:07 +0000 (13:09 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18693

As title

Reviewed By: protonu

Differential Revision: D14662880

fbshipit-source-id: 3664faa660a04e1f528a413d2a1700b872c3c684

caffe2/quantization/server/conv_dnnlowp_op.cc
caffe2/quantization/server/conv_relu_op.cc

index bb2aeee..d585af5 100644 (file)
@@ -1534,9 +1534,6 @@ template class ConvDNNLowPOp<uint8_t, true>;
 template class ConvDNNLowPOp<uint16_t, false>;
 template class ConvDNNLowPOp<uint16_t, true>;
 
-OPERATOR_SCHEMA(ConvRelu).NumInputs(2, 3).NumOutputs(1).TensorInferenceFunction(
-    ConvPoolOpBase<CPUContext>::TensorInferenceForConv);
-
 REGISTER_CPU_OPERATOR_WITH_ENGINE(Conv, DNNLOWP, ConvDNNLowPOp<uint8_t, false>);
 REGISTER_CPU_OPERATOR_WITH_ENGINE(
     ConvRelu,
index 6668389..f3511a8 100644 (file)
@@ -64,6 +64,13 @@ bool ConvReluOp<T, Context>::RunOnDeviceWithOrderNHWC() {
   return true;
 }
 
+OPERATOR_SCHEMA(ConvRelu)
+    .NumInputs(2, 3)
+    .NumOutputs(1)
+    .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
+    .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
+        ConvPoolOpBase<CPUContext>::CostInferenceForConv));
+
 REGISTER_CPU_OPERATOR(ConvRelu, ConvReluOp<float, CPUContext>);
 
 } // namespace caffe2