Small clean up of aten_op
authorRoy Li <royboy@fb.com>
Tue, 12 Mar 2019 04:01:21 +0000 (21:01 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Mar 2019 04:04:16 +0000 (21:04 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17530

Reviewed By: ezyang

Differential Revision: D14237931

fbshipit-source-id: fb73d63d89fab0622097a49be6ed0b75ddb02a7c

caffe2/contrib/aten/aten_op_template.h
caffe2/contrib/aten/gen_op.py

index 626534e..256b846 100644 (file)
@@ -55,7 +55,7 @@ private:
   }
 
   at::Type& typeFor(const Tensor& ten) {
-    return at::getNonVariableType(backend(), atScalarTypeFor(ten.meta()));
+    return at::getNonVariableType(backend(), typeMetaToScalarType(ten.meta()));
   }
   at::Tensor tensorWrapping(const Tensor& ten_) {
     auto& ten = const_cast<Tensor&>(ten_);
@@ -75,19 +75,6 @@ private:
     return results;
   }
 
-  at::ScalarType atScalarTypeFor(const TypeMeta & meta) {
-    #define DEFINE_IF(ctype,aten_name,_) \
-    if(meta.Match<ctype>()) { \
-      return at::k##aten_name; \
-    }
-    AT_FORALL_SCALAR_TYPES(DEFINE_IF)
-    #undef DEFINE_IF
-    // Special case for bool, since the type in ATen is actually Byte
-    if (meta.Match<bool>()) {
-      return at::kByte;
-    }
-    CAFFE_THROW("Unknown type meta"); // TODO: improve error message...
-  }
   void assignTo(Tensor* dst, const at::Tensor& src_) {
     at::Tensor src = src_.contiguous();
     auto at_sizes = src.sizes();
@@ -129,8 +116,8 @@ private:
     return s.toLong();
   }
 
-  void assignTo(Tensor* dst, at::Type& inferred_type, at::Scalar scalar) {
-    switch(inferred_type.scalarType()) {
+  void assignTo(Tensor* dst, at::ScalarType scalar_type, at::Scalar scalar) {
+    switch(scalar_type) {
       #define DEFINE_CASE(ctype,aten_name,native) \
         case at::k##aten_name: { \
           auto value = extract_##native(scalar); \
@@ -208,27 +195,6 @@ private:
     }
     return result;
   }
-  at::ScalarType stringToScalarType(const std::string & name) {
-    #define DEFINE_IF(type,aten) \
-      if(#type == name) \
-        return at::k##aten;
-    DEFINE_IF(at::Half, Half)
-    DEFINE_IF(float, Float)
-    DEFINE_IF(double, Double)
-    DEFINE_IF(uint8, Byte)
-    DEFINE_IF(int8, Char)
-    DEFINE_IF(int16, Short)
-    DEFINE_IF(int32, Int)
-    DEFINE_IF(int64, Long)
-    CAFFE_THROW("unsupported type annotation: ", name);
-  }
-  at::TypeExtendedInterface & stringToType(const std::string & name) {
-    return at::getNonVariableType(backend(), stringToScalarType(name));
-  }
-  at::TypeExtendedInterface * readTypeAttribute(const std::string & name) {
-    CAFFE_ENFORCE(OperatorBase::HasSingleArgumentOfType<std::string>(name));
-    return &stringToType(OperatorBase::GetSingleArgument<std::string>(name, ""));
-  }
 };
 
 }
index a582727..47794d3 100755 (executable)
@@ -73,7 +73,7 @@ def value_is_tensor_type(v):
 # for each aten type, how do we handle a return value of that type?
 RETURN_MAP = {
     'Tensor': 'assignTo(Output(${offset}),${output});',
-    'Scalar': 'assignTo(Output(${offset}),*inferred_type, ${output});',
+    'Scalar': 'assignTo(Output(${offset}),self.scalar_type(), ${output});',
     'bool': 'assignToValue<int64_t>(Output(${offset}),${output});',
     'int64_t': 'assignToValue<int64_t>(Output(${offset}),${output});',
     'std::vector<Tensor>': 'assignListStartingAt(${offset}, ${output});',
@@ -235,16 +235,12 @@ if __name__ == '__main__':
             'initialization': [],
             'key': str(key),
         }
-        defined_inferred_type = False
 
         if 'namespace' not in o['method_of'] and 'Tensor' not in o['method_of']:
             # methods on type like 'ones' or 'zeros' always take a
             # string attribute that is translated into the at::Type object
             # e.g. "Float" is at::kFloat
             assert('Type' in o['method_of'])
-            defined_inferred_type = True
-            env['initialization'].append(
-                'auto inferred_type = readTypeAttribute("type");')
 
         static_tensor_inputs = sum(arg['type'] != 'TensorList' and value_is_tensor_type(arg) for arg in o['arguments'])
         has_tensorlist = any(arg['type'] == 'TensorList' for arg in o['arguments'])
@@ -269,12 +265,6 @@ if __name__ == '__main__':
                 env['statements'].append(
                     'auto {} = peek({}, {});'.format(arg['name'], real_inputs, view_length))
                 real_inputs += 1
-                if arg['dynamic_type'] == 'Tensor' and not defined_inferred_type:
-                    # first tensor input is used to define the output type.
-                    defined_inferred_type = True
-                    env['statements'].append(
-                        'auto inferred_type = &at::getType({});'.format(
-                            arg['name']))
             else:
                 init = CT(ARGUMENT_MAP[arg['type']]).substitute(env, arg=arg['name'])
                 env['initialization'].append(init)
@@ -286,13 +276,10 @@ if __name__ == '__main__':
 
         if 'namespace' in o['method_of']:
             env['invocation'] = CT("at::${name}(${arguments})").substitute(env)
-        elif 'Tensor' in o['method_of']:
+        else:
+            assert('Tensor' in o['method_of'])
             env['invocation'] = "self.{}({})".format(
                 o['name'], ', '.join(env['arguments'][1:]))
-        else:
-            assert('Type' in o['method_of'])
-            env['invocation'] = CT(
-                'inferred_type->${name}(${arguments})').substitute(env)
 
         top_env['implementations'].append(OPTION_TEMPLATE.substitute(env))
         key += 1