};
testSimple();
- auto testOne = [&](int ti, int tj, int toi, int toj) {
+ auto testOne = [&](int ti, int tj) {
Graph graph;
Var i0 = Var::asNewInput(graph);
float max_diff = (outputs.front() - out0).abs().max().item<double>();
ASSERT_TRUE(max_diff < 1e-6);
};
- testOne(0, 0, 0, 0);
- testOne(0, 1, 0, 0);
- testOne(1, 2, 0, 0);
- testOne(0, 2, 0, 0);
-
- testOne(0, 0, 0, 1);
- testOne(0, 1, 1, 2);
- testOne(1, 2, 0, 2);
+ testOne(0, 0);
+ testOne(0, 1);
+ testOne(1, 2);
+ testOne(0, 2);
auto createFusedConcat =
[](Graph& graph, at::ArrayRef<Value*> inputs, int64_t dim) -> Value* {