bidirectional_,
&context_);
- std::vector<Tensor> allOutputs(OutputSize());
- allOutputs.at(0) = copy_ctor(std::get<0>(results));
+ auto output = copy_ctor(std::get<0>(results));
if (batch_first_) {
- allOutputs.at(0) = transpose(allOutputs.at(0), 0, 1, &context_);
- }
- allOutputs.at(1) = copy_ctor(std::get<1>(results));
- allOutputs.at(2) = copy_ctor(std::get<2>(results));
- for (int i = 0; i < OutputSize(); i++) {
- auto output = XOutput(i, allOutputs.at(i).sizes(), dtype<float>());
- context_.CopyItemsSameDevice(
- allOutputs.at(i).dtype(),
- allOutputs.at(i).numel(),
- allOutputs.at(i).template data<float>(),
- output.template mutable_data<float>());
+ output = transpose(output, 0, 1, &context_);
}
+ SetOutputTensor(0, copy_ctor(output));
+ SetOutputTensor(1, copy_ctor(std::get<1>(results)));
+ SetOutputTensor(2, copy_ctor(std::get<2>(results)));
return true;
}