Verify def before infer fensor (#18129)
authorGerard Goossen <ggoossen@fb.com>
Fri, 22 Mar 2019 13:33:24 +0000 (06:33 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 13:36:25 +0000 (06:36 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18129

A lot of tensor interference function assume the operator passes the schema.
So call Verity to make sure this is actually the case.

Created diff before to add checking in Concat (https://github.com/pytorch/pytorch/pull/17110), but I encountered lot more places where this is assumed (for example ElementwiseOpShapeInference)

Reviewed By: mdschatz

Differential Revision: D14503933

fbshipit-source-id: cf0097b8c3e4beb1cded6b61e092a6adee4b8fcb

caffe2/core/operator_schema.h
caffe2/python/operator_test/shape_inference_test.py

index 02eb2f5..e902fb5 100644 (file)
@@ -14,6 +14,7 @@
 #include "caffe2/core/logging.h"
 #include "caffe2/proto/caffe2_pb.h"
 #include "caffe2/utils/filler.h"
+#include "caffe2/utils/proto_utils.h"
 
 namespace caffe2 {
 
@@ -186,6 +187,10 @@ class CAFFE2_API OpSchema {
   inline vector<TensorShape> InferTensor(
       const OperatorDef& def,
       const vector<TensorShape>& input_type_shape) const {
+    CAFFE_ENFORCE(
+        Verify(def),
+        "(InferTensor) Operator def did not pass schema checking: ",
+        ProtoDebugString(def));
     return tensor_inference_function_(def, input_type_shape);
   }
 
index 36cfbac..a78d943 100644 (file)
@@ -415,8 +415,8 @@ class TestShapeInference(test_util.TestCase):
         net = core.Net("concat")
 
         net.Concat(["A", "B"], ["C", "splits"], axis=1)
-        net.Concat(["C", "D"], ["E"], order="NCHW")
-        net.Concat(["E", "F"], ["G"], add_axis=1, order="NHWC")
+        net.Concat(["C", "D"], ["E", "splitsE"], order="NCHW")
+        net.Concat(["E", "F"], ["G", "splitsG"], add_axis=1, order="NHWC")
         (shapes, types) = workspace.InferShapesAndTypes(
             [net],
             {
@@ -435,8 +435,8 @@ class TestShapeInference(test_util.TestCase):
         net = core.Net("concat")
 
         net.Concat(["A", "B"], ["C", "splits"], axis=1)
-        net.Concat(["C", "D"], ["E"], order="NCHW")
-        net.Concat(["E", "F"], ["G"], add_axis=1, order="NHWC")
+        net.Concat(["C", "D"], ["E", "splitsE"], order="NCHW")
+        net.Concat(["E", "F"], ["G", "splitsG"], add_axis=1, order="NHWC")
         (shapes, types) = workspace.InferShapesAndTypes(
             [net],
             blob_dimensions={