densities = ['Dense', 'Sparse']
extension_backends = ['MSNPU', 'XLA']
-# scalar_name, c_type, accreal, th_scalar_type, is_floating_type
+# scalar_name, c_type, accreal, is_floating_type
scalar_types = [
- ('Bool', 'uint8_t', 'BoolAccrealNotDefined', 'uint8_t', False),
- ('Byte', 'uint8_t', 'Long', 'uint8_t', False),
- ('Char', 'int8_t', 'Long', 'int8_t', False),
- ('Double', 'double', 'Double', 'double', True),
- ('Float', 'float', 'Double', 'float', True),
- ('Int', 'int', 'Long', 'int32_t', False),
- ('Long', 'int64_t', 'Long', 'int64_t', False),
- ('Short', 'int16_t', 'Long', 'int16_t', False),
- ('Half', 'Half', 'Double', 'at::Half', True),
+ ('Bool', 'uint8_t', 'BoolAccrealNotDefined', False),
+ ('Byte', 'uint8_t', 'Long', False),
+ ('Char', 'int8_t', 'Long', False),
+ ('Double', 'double', 'Double', True),
+ ('Float', 'float', 'Double', True),
+ ('Int', 'int', 'Long', False),
+ ('Long', 'int64_t', 'Long', False),
+ ('Short', 'int16_t', 'Long', False),
+ ('Half', 'Half', 'Double', True),
]
# shared environment for non-derived base classes Type.h Tensor.h Storage.h
def generate_storage_type_and_tensor(backend, density, scalar_type, declarations):
- scalar_name, c_type, accreal, th_scalar_type, is_floating_type = scalar_type
+ scalar_name, c_type, accreal, is_floating_type = scalar_type
env = {}
density_tag = 'Sparse' if density == 'Sparse' else ''
env['Density'] = density
env['ScalarName'] = scalar_name
env['ScalarType'] = c_type
- env['THScalarType'] = th_scalar_type
env['AccScalarName'] = accreal
env['isFloatingType'] = is_floating_type
env['isIntegralType'] = not is_floating_type
def generate_type_extension_backend_derived_types(backend):
env = {}
env['Backend'] = backend
- for scalar_name, c_type, _, _, _ in scalar_types:
+ for scalar_name, c_type, _, _ in scalar_types:
env['Type'] = "{}{}Type".format(backend, scalar_name)
env['ScalarName'] = scalar_name
env['ScalarType'] = c_type
def generate_legacy_th_dispatcher(backend, density, scalar_type, declarations):
assert density != 'Sparse'
- scalar_name, c_type, accreal, th_scalar_type, is_floating_type = scalar_type
+ scalar_name, c_type, accreal, is_floating_type = scalar_type
env = {}
env['Backend'] = backend
env['Dispatcher'] = "LegacyTH{}{}Dispatcher".format(backend, scalar_name)