        ge = self.checkTrace(scaleshift, inputs)
        self.assertExpectedGraph(ge.graph_for(*inputs))
+    @staticmethod
+    def _test_cast_Float(self, device):
+        def f(x, y):
+            z = x.float()
+            return z + y
+
+        inputs = [
+            torch.randn(4, 4, dtype=torch.double, device=device),
+            torch.randn(4, 4, dtype=torch.float, device=device),
+        ]
+
+        ge = self.checkScript(f, inputs)
+        self.assertAllFused(ge.graph_for(*inputs))
+
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @enable_cpu_fuser
+    def test_cast_Float(self):
+        return self._test_cast_Float(self, 'cpu')
+
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @unittest.skipIf(not RUN_CUDA, "No CUDA")
+    @skipIfRocm
+    def test_cast_Float_cuda(self):
+        return self._test_cast_Float(self, 'cuda')
+
    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
return "int";
} else if (t->kind() == TypeKind::FloatType) {
return "float";
+ } else if (t->kind() == TypeKind::BoolType) {
+ return "bool";
} else if (t->kind() == TypeKind::TensorType) {
auto const tt = t->cast<TensorType>();
return calcScalarTypeName(tt->scalarType());
}
static std::string typeCastedValueName(const std::shared_ptr<c10::Type> t, const at::ScalarType outtype, const std::string& vn) {
-  if (t->kind() == TypeKind::IntType) {
+  if (t->kind() == TypeKind::IntType || t->kind() == TypeKind::BoolType) {
    if (! isIntegralType(outtype)) {
      return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
    }