* @param iteration iteration
* @return nntrainer::sharedConstTensor
*/
- void forward(int iteration, NodeWatcher &next_node);
+ void forward(int iteration, bool verify);
/**
* @brief backward pass of the node with verifying inputs/gradients/outputs
*/
bool isLossType() { return node->requireLabel(); }
+ /**
+ * @brief support in-place operation
+ *
+ * @return true if support in-place else false
+ */
+ bool supportInPlace() { return node->supportInPlace(); }
+
+ /**
+ * @brief support backwarding operation
+ *
+ * @return true if support backwarding else false
+ */
+ bool supportBackwarding() { return node->supportBackwarding(); }
+
private:
NodeType node;
std::vector<nntrainer::Tensor> expected_output;
}
}
-void NodeWatcher::forward(int iteration, NodeWatcher &next_node) {
+void NodeWatcher::forward(int iteration, bool verify_forward) {
std::stringstream ss;
ss << "forward failed at " << node->getName() << " at iteration "
<< iteration;
out.push_back(node->getOutput(idx));
}
- if (!next_node.node->supportInPlace() &&
- getNodeType() != nntrainer::MultiOutLayer::type)
+ if (verify_forward && getNodeType() != nntrainer::MultiOutLayer::type)
verify(out, expected_output, err_msg + " at output");
}
nntrainer::Tensor::epsilon);
}
- for (auto it = nodes.begin(); it != nodes.end() - 1; ++it) {
- it->forward(iteration, *(it + 1));
+ auto it = nodes.begin();
+ for (; it != nodes.end() - 1; ++it) {
+ it->forward(iteration, !(it + 1)->supportInPlace());
}
+ it->forward(iteration, true);
if (loss_nodes.size()) {
nn.backwarding(label, iteration);
- for (auto it = nodes.rbegin(); it != nodes.rend() - 1; it++) {
- if (it == nodes.rend())
- it->backward(iteration, true, !optimize);
- else
+ for (auto it = nodes.rbegin(); it != nodes.rend(); it++) {
+ if (it == nodes.rend() - 1) {
+ /** check last layer backwarding only when not input layers */
+ if (it->supportBackwarding())
+ it->backward(iteration, true, !optimize);
+ } else {
it->backward(iteration, !optimize, !optimize);
+ }
}
} else {
EXPECT_THROW(nn.backwarding(label, iteration), std::runtime_error);