    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @skipIfRocm
+    def test_addcmul_cuda(self):
+        t = torch.randn(1, 4, dtype=torch.float, device='cuda')
+        t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
+        t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')
+
+        def foo(t, t1, t2):
+            # t1 is deliberately unused here, hence allow_unused=True below
+            return t.addcmul(t + 1, t2, value=0.1)
+
+        ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
+        graph = ge.graph_for(t, t1, t2)
+        self.assertAllFused(graph)
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_lerp_cuda(self):
+        start = torch.randn(4, 1, dtype=torch.float, device='cuda')
+        end = torch.randn(1, 4, dtype=torch.float, device='cuda')
+        weight = torch.tensor(0.5, dtype=torch.float, device='cuda')
+
+        # scalar weight overload
+        def foo_weight_scalar(start, end):
+            return torch.lerp(start + 1, end, 0.5)
+
+        # tensor weight overload (weight is captured as a closed-over constant)
+        def foo_weight_tensor(start, end):
+            return torch.lerp(start + 1, end, weight)
+
+        ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end))
+        graph = ge_weight_scalar.graph_for(start, end)
+        self.assertAllFused(graph)
+
+        ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end))
+        graph = ge_weight_tensor.graph_for(start, end)
+        self.assertAllFused(graph)
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
    def test_concat_cuda(self):
        hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
        cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
    return true;
  if (n->matches(
-          "aten::dropout(Tensor input, float p, bool train) -> Tensor", attr::train)) {
+          "aten::dropout(Tensor input, float p, bool train) -> Tensor",
+          attr::train)) {
    return n->get<bool>(attr::train).value();
  }
        static_cast<bool (*)(Node*)>(isDifferentiable));
  }
+  // Gradient formulas are only defined for floating point scalars,
+  // so we fall back to autograd for nodes with other scalar inputs.
+  for (const Value* input : n->inputs()) {
+    if (input->type() == NumberType::get()) {
+      return false;
+    }
+  }
+
  return false;
}
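A note on the guard just added: gradients only make sense for floating point values, so any node that still carries a NumberType (Scalar) input is sent down the autograd path. A minimal eager-mode sketch of the underlying fact (illustration only, not part of the diff):

import torch

x = torch.tensor(2.0, requires_grad=True)   # float scalars are differentiable

try:
    torch.tensor(2).requires_grad_(True)    # int scalars are not
except RuntimeError as err:
    # raises "only Tensors of floating point dtype can require gradients"
    print(err)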
    Node* node,
    const ArrayRef<Value*>& grads) {
  auto graph = node->owningGraph();
-
  auto compiled_graphs = gradientInfoForSchema(node->schema());
  if (!compiled_graphs) {
    return c10::nullopt;
  }
  auto bw_graph = compiled_graphs->backward;
  auto grad_vec = grads.vec();
  if (needTrimGrad(node)) {
-    grad_vec.erase(grad_vec.begin()+1, grad_vec.end());
+    grad_vec.erase(grad_vec.begin() + 1, grad_vec.end());
  }
  auto it = grad_vec.begin();
  grad_vec.insert(it, new_outputs.back());
    {aten::__or__, "${0} || ${1}"},
    {aten::__rshift__, "${0} >> ${1}"},
    {aten::__xor__, "${0} ^ ${1}"},
+    {aten::addcmul, "${cast_0} + ${cast_3} * ${cast_1} * ${cast_2}"},
    {aten::div, "${cast_0} / ${cast_1}"},
    {aten::eq, "${0} == ${1}"},
    {aten::fmod, "fmodf(${cast_0}, ${cast_1})"},
    {aten::gt, "${0} > ${1}"},
    {aten::le, "(${0} <= ${1})"},
    {aten::lt, "${0} < ${1}"},
+    {aten::lerp, "${cast_0} + ${cast_2} * (${cast_1} - ${cast_0})"},
    {aten::type_as, "(${cast_0})"},
    {aten::mul, "${cast_0} * ${cast_1}"},
    {aten::ne, "${0} != ${1}"},
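The two new codegen strings mirror the eager definitions: addcmul(self, tensor1, tensor2, value) is self + value * tensor1 * tensor2 (arguments 0, 1, 2, 3), and lerp(start, end, weight) is start + weight * (end - start) (arguments 0, 1, 2). A minimal eager-mode sanity sketch (illustration only, not part of the diff):

import torch

s, t1, t2 = torch.randn(3), torch.randn(3), torch.randn(3)
value = 0.1
# ${cast_0} + ${cast_3} * ${cast_1} * ${cast_2}
assert torch.allclose(torch.addcmul(s, t1, t2, value=value),
                      s + value * t1 * t2)

start, end, weight = torch.randn(3), torch.randn(3), 0.25
# ${cast_0} + ${cast_2} * (${cast_1} - ${cast_0})
assert torch.allclose(torch.lerp(start, end, weight),
                      start + weight * (end - start))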
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/resource_guard.h>
-#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
+#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/symbolic_variable.h>
#include <torch/csrc/jit/tracer.h>
"aten::log10(Tensor self) -> Tensor",
"aten::log1p(Tensor self) -> Tensor",
"aten::log2(Tensor self) -> Tensor",
+ "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
+ "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
"aten::max(Tensor self, Tensor other) -> Tensor",
"aten::min(Tensor self, Tensor other) -> Tensor",
"aten::mul(Tensor self, Tensor other) -> Tensor",
"aten::lt(Tensor self, Tensor other) -> Tensor",
"aten::lt(Tensor self, Scalar other) -> Tensor",
+ "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor",
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
"aten::type_as(Tensor self, Tensor other) -> Tensor",
            return torch.matmul(self, other), backward
    )",
    R"(
+        def addcmul(self,
+                    tensor1,
+                    tensor2,
+                    *,
+                    value: float = 1.0):
+            def backward(grad_output):
+                grad = grad_output * value
+                grad_tensor1 = (grad * tensor2)._grad_sum_to_size(tensor1.size())
+                grad_tensor2 = (grad * tensor1)._grad_sum_to_size(tensor2.size())
+                return grad_output._grad_sum_to_size(self.size()), grad_tensor1, grad_tensor2, None
+            return torch.addcmul(self, tensor1, tensor2, value=value), backward
+
        def _dim_arange(like,
                        dim: int):
            def backward(grad_output):
            return torch.full_like(self, fill_value), backward
+        def lerp_0(self,
+                   end,
+                   weight: float):
+            def backward(grad_output):
+                grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self.size())
+                grad_end = (grad_output * weight)._grad_sum_to_size(end.size())
+                return grad_self, grad_end, None
+            return torch.lerp(self, end, weight), backward
+
+        def lerp_1(self,
+                   end,
+                   weight):
+            def backward(grad_output):
+                grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self.size())
+                grad_end = (grad_output * weight)._grad_sum_to_size(end.size())
+                return grad_self, grad_end, None
+            return torch.lerp(self, end, weight), backward
+
        def mul(self, other):
            def backward(grad_output):
                # self & other are used in backward. No need to pass in their size
            schema_name.length(),
            schema_name.substr(0, pos));
  }
+
  return schema_string;
}
    return cache_it->second;
  } else {
    auto schema_str = canonicalSchemaString(schema);
+    // Specialize Scalar to float in the argument types of the node schema.
+    // This is used to:
+    //   1. define scalar arguments as float in the TorchScript autodiff
+    //      formulas (e.g. `value: float = 1.0` in addcmul above), and
+    //   2. ensure that no graph node carries a Scalar argument: every scalar
+    //      argument should already be passed as a float value, since
+    //      Scalar/int aren't differentiable anyway.
+    c10::ReplaceAll(schema_str, "Scalar", "float");
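+    // For example, the addcmul schema key
+    //   "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"
+    // becomes
+    //   "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, float value=1) -> Tensor",
+    // matching the float-typed signature of the formula registered above.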
+
    auto sym_script_it = schema_to_graphs.find(schema_str);
    if (sym_script_it != schema_to_graphs.end()) {
// merged. Ideally this should all go into native_functions.yaml
#include <c10/util/Optional.h>
+#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/module.h>