From b161ac9634d11ae3c6a203b22e057cce9b7137d7 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Mon, 11 Mar 2019 21:01:21 -0700 Subject: [PATCH] Small clean up of aten_op Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17530 Reviewed By: ezyang Differential Revision: D14237931 fbshipit-source-id: fb73d63d89fab0622097a49be6ed0b75ddb02a7c --- caffe2/contrib/aten/aten_op_template.h | 40 +++------------------------------- caffe2/contrib/aten/gen_op.py | 19 +++------------- 2 files changed, 6 insertions(+), 53 deletions(-) diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h index 626534e..256b846 100644 --- a/caffe2/contrib/aten/aten_op_template.h +++ b/caffe2/contrib/aten/aten_op_template.h @@ -55,7 +55,7 @@ private: } at::Type& typeFor(const Tensor& ten) { - return at::getNonVariableType(backend(), atScalarTypeFor(ten.meta())); + return at::getNonVariableType(backend(), typeMetaToScalarType(ten.meta())); } at::Tensor tensorWrapping(const Tensor& ten_) { auto& ten = const_cast(ten_); @@ -75,19 +75,6 @@ private: return results; } - at::ScalarType atScalarTypeFor(const TypeMeta & meta) { - #define DEFINE_IF(ctype,aten_name,_) \ - if(meta.Match()) { \ - return at::k##aten_name; \ - } - AT_FORALL_SCALAR_TYPES(DEFINE_IF) - #undef DEFINE_IF - // Special case for bool, since the type in ATen is actually Byte - if (meta.Match()) { - return at::kByte; - } - CAFFE_THROW("Unknown type meta"); // TODO: improve error message... - } void assignTo(Tensor* dst, const at::Tensor& src_) { at::Tensor src = src_.contiguous(); auto at_sizes = src.sizes(); @@ -129,8 +116,8 @@ private: return s.toLong(); } - void assignTo(Tensor* dst, at::Type& inferred_type, at::Scalar scalar) { - switch(inferred_type.scalarType()) { + void assignTo(Tensor* dst, at::ScalarType scalar_type, at::Scalar scalar) { + switch(scalar_type) { #define DEFINE_CASE(ctype,aten_name,native) \ case at::k##aten_name: { \ auto value = extract_##native(scalar); \ @@ -208,27 +195,6 @@ private: } return result; } - at::ScalarType stringToScalarType(const std::string & name) { - #define DEFINE_IF(type,aten) \ - if(#type == name) \ - return at::k##aten; - DEFINE_IF(at::Half, Half) - DEFINE_IF(float, Float) - DEFINE_IF(double, Double) - DEFINE_IF(uint8, Byte) - DEFINE_IF(int8, Char) - DEFINE_IF(int16, Short) - DEFINE_IF(int32, Int) - DEFINE_IF(int64, Long) - CAFFE_THROW("unsupported type annotation: ", name); - } - at::TypeExtendedInterface & stringToType(const std::string & name) { - return at::getNonVariableType(backend(), stringToScalarType(name)); - } - at::TypeExtendedInterface * readTypeAttribute(const std::string & name) { - CAFFE_ENFORCE(OperatorBase::HasSingleArgumentOfType(name)); - return &stringToType(OperatorBase::GetSingleArgument(name, "")); - } }; } diff --git a/caffe2/contrib/aten/gen_op.py b/caffe2/contrib/aten/gen_op.py index a582727..47794d3 100755 --- a/caffe2/contrib/aten/gen_op.py +++ b/caffe2/contrib/aten/gen_op.py @@ -73,7 +73,7 @@ def value_is_tensor_type(v): # for each aten type, how do we handle a return value of that type? RETURN_MAP = { 'Tensor': 'assignTo(Output(${offset}),${output});', - 'Scalar': 'assignTo(Output(${offset}),*inferred_type, ${output});', + 'Scalar': 'assignTo(Output(${offset}),self.scalar_type(), ${output});', 'bool': 'assignToValue(Output(${offset}),${output});', 'int64_t': 'assignToValue(Output(${offset}),${output});', 'std::vector': 'assignListStartingAt(${offset}, ${output});', @@ -235,16 +235,12 @@ if __name__ == '__main__': 'initialization': [], 'key': str(key), } - defined_inferred_type = False if 'namespace' not in o['method_of'] and 'Tensor' not in o['method_of']: # methods on type like 'ones' or 'zeros' always take a # string attribute that is translated into the at::Type object # e.g. "Float" is at::kFloat assert('Type' in o['method_of']) - defined_inferred_type = True - env['initialization'].append( - 'auto inferred_type = readTypeAttribute("type");') static_tensor_inputs = sum(arg['type'] != 'TensorList' and value_is_tensor_type(arg) for arg in o['arguments']) has_tensorlist = any(arg['type'] == 'TensorList' for arg in o['arguments']) @@ -269,12 +265,6 @@ if __name__ == '__main__': env['statements'].append( 'auto {} = peek({}, {});'.format(arg['name'], real_inputs, view_length)) real_inputs += 1 - if arg['dynamic_type'] == 'Tensor' and not defined_inferred_type: - # first tensor input is used to define the output type. - defined_inferred_type = True - env['statements'].append( - 'auto inferred_type = &at::getType({});'.format( - arg['name'])) else: init = CT(ARGUMENT_MAP[arg['type']]).substitute(env, arg=arg['name']) env['initialization'].append(init) @@ -286,13 +276,10 @@ if __name__ == '__main__': if 'namespace' in o['method_of']: env['invocation'] = CT("at::${name}(${arguments})").substitute(env) - elif 'Tensor' in o['method_of']: + else: + assert('Tensor' in o['method_of']) env['invocation'] = "self.{}({})".format( o['name'], ', '.join(env['arguments'][1:])) - else: - assert('Type' in o['method_of']) - env['invocation'] = CT( - 'inferred_type->${name}(${arguments})').substitute(env) top_env['implementations'].append(OPTION_TEMPLATE.substitute(env)) key += 1 -- 2.7.4