%ret: Tensor = aten::cat(%ten_list2, %dim)
return (%ret)
)IR";
+
+// TorchScript source for torch.cumsum with the default dtype. The result is
+// .clone()'d so the graph returns a fresh tensor rather than the op's direct
+// output (NOTE(review): presumably so the managed-output path of the static
+// runtime is exercised — confirm against testStaticRuntime's expectations).
+const auto cumsum_script = R"JIT(
+  def forward(self, a: Tensor, dim: int):
+      return torch.cumsum(a, dim).clone()
+)JIT";
+
+// TorchScript source for torch.cumsum with an explicit dtype argument,
+// covering the optional `dtype` slot of the aten::cumsum schema. As above,
+// the result is cloned so a fresh tensor is returned.
+const auto cumsum_script_dtype = R"JIT(
+  def forward(self, a: Tensor, dim: int, dtype: int):
+      return torch.cumsum(a, dim, dtype=dtype).clone()
+)JIT";
std::vector<IValue> args1{c, d, 1};
testStaticRuntime(cat_script, args0, args1);
}
+
+
+// Exercises aten::cumsum (no dtype) through the static runtime:
+// first a single run on a 2x3 input along dim 0, then a run with args0
+// followed by a differently-shaped 4x3 input along dim 1.
+// NOTE(review): passing both args0 and args1 presumably makes
+// testStaticRuntime re-run the graph so the reused output tensor must be
+// resized between invocations — confirm against testStaticRuntime.
+TEST(StaticRuntime, IndividualOps_Cumsum) {
+  auto a = at::randn({2, 3});
+  std::vector<IValue> args0{a, 0};
+  testStaticRuntime(cumsum_script, args0);
+
+  auto b = at::randn({4, 3});
+  std::vector<IValue> args1{b, 1};
+  testStaticRuntime(cumsum_script, args0, args1);
+}
+
+// Exercises aten::cumsum with an explicit dtype (Float) through the static
+// runtime, mirroring IndividualOps_Cumsum: one single-args run, then a
+// two-args run with a differently-shaped input to force output resizing
+// between invocations (NOTE(review): resizing behavior assumed from the
+// sibling test — confirm against testStaticRuntime).
+TEST(StaticRuntime, IndividualOps_CumsumDtype) {
+  auto a = at::randn({1, 2});
+  auto dtype = at::ScalarType::Float;
+  std::vector<IValue> args0{a, 0, dtype};
+  testStaticRuntime(cumsum_script_dtype, args0);
+
+  auto b = at::randn({3, 4});
+  std::vector<IValue> args1{b, 1, dtype};
+  testStaticRuntime(cumsum_script_dtype, args0, args1);
+}
};
});
+// Static-runtime out-variant implementation of aten::cumsum.
+// The outer lambda is invoked per graph node: it validates the node against
+// the expected schema and, on a match, returns the per-invocation kernel.
+REGISTER_OPERATOR_FUNCTOR(aten::cumsum, aten_cumsum, [](Node* n) -> SROperator {
+  if (!n->matches(torch::schema(
+          "aten::cumsum(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"))) {
+    // Unexpected schema: log/dump it and return nullptr so this functor is
+    // not used for the node (caller decides how to handle the nullptr).
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    // Inputs per the schema above: self tensor, dim, optional dtype.
+    const auto& input = p_node->Input(0).toTensor();
+    const auto dim = p_node->Input(1).toInt();
+    const auto dtype = p_node->Input(2).toOptional<c10::ScalarType>();
+
+    // First execution: no output tensor exists yet, so allocate one via the
+    // functional overload.
+    if (p_node->Output(0).isNone()) {
+      p_node->Output(0) = at::cpu::cumsum(input, dim, dtype);
+      return;
+    }
+
+    // Subsequent executions: reuse the existing output tensor through the
+    // out-variant. fastResizeToZero first shrinks it to zero elements —
+    // NOTE(review): presumably so cumsum_out can resize to the new shape
+    // without preserving/copying stale data; confirm its contract.
+    auto& output = p_node->Output(0).toTensor();
+    fastResizeToZero(output);
+    at::cpu::cumsum_out(output, input, dim, dtype);
+  };
+});
+
namespace {
void check_cat_no_zero_dim(const std::vector<at::Tensor>& tensors) {