Validate matching input shapes in Int8Add operator (#14520)
authorMarat Dukhan <marat@fb.com>
Wed, 5 Dec 2018 19:39:46 +0000 (11:39 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 20:00:23 +0000 (12:00 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14520

Default engine doesn't support broadcast semantics in Int8Add operator. This patch adds a check that shapes are equivalent.

Reviewed By: bertmaher

Differential Revision: D13250922

fbshipit-source-id: 8526d07723bd9a34d54dee04d121c57f8b33c481

caffe2/operators/quantized/int8_add_op.h

index 5689880..4f30195 100644 (file)
@@ -32,6 +32,11 @@ class Int8AddOp final : public Operator<CPUContext> {
     const auto& B = Inputs()[1]->template Get<Int8TensorCPU>();
     auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
 
+    CAFFE_ENFORCE_EQ(
+        A.t.sizes(),
+        B.t.sizes(),
+        "inputs must have the same shape (broadcast semantics is not supported)");
+
     /*
      * Record quantization parameters for A and B inputs, because if the op is
      * in-place, we may overwrite these parameters later, when we set