self.fc3 = torch.nn.Linear(10, 2)
self.loss = torch.nn.MSELoss()
# determine which layer to set to non-trainable
- if idx == 1:
- for param in self.fc1.parameters():
- param.requires_grad = False
- elif idx == 2:
- for param in self.fc2.parameters():
- param.requires_grad = False
+ fc_layer_list = [self.fc1, self.fc2, self.fc3]
+ for param in fc_layer_list[idx-1].parameters():
+ param.requires_grad = False
def forward(self, inputs, labels):
out = torch.relu(self.fc1(inputs[0]))
label_dims=[(3,2)],
name="non_trainable_fc_idx2"
)
+
+ non_trainable_fc_idx3 = NonTrainableFC(idx=3)
+ record_v2(
+ non_trainable_fc_idx3,
+ iteration=2,
+ input_dims=[(3,3)],
+ input_dtype=[float],
+ label_dims=[(3,2)],
+ name="non_trainable_fc_idx3"
+ )
# Function to check the created golden test file
- inspect_file("fc_relu_decay.nnmodelgolden")
+ inspect_file("non_trainable_fc_idx3.nnmodelgolden")
std::string fc1_trainable = (idx == 1) ? "trainable=false" : "trainable=true";
std::string fc2_trainable = (idx == 2) ? "trainable=false" : "trainable=true";
+ std::string fc3_trainable = (idx == 3) ? "trainable=false" : "trainable=true";
- return [fc1_trainable, fc2_trainable]() {
+ return [fc1_trainable, fc2_trainable, fc3_trainable]() {
std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
nn->setProperty({"batch_size=3"});
auto outer_graph = makeGraph({
{"input", {"name=in", "input_shape=1:1:3"}},
{"fully_connected",
- {"name=fc1", "input_layers=in", "unit=10", fc1_trainable}},
- {"activation", {"name=act1", "input_layers=fc1", "activation=relu"}},
+ {"name=fc1", "input_layers=in", "unit=10", "activation=relu",
+ fc1_trainable}},
{"fully_connected",
- {"name=fc2", "input_layers=act1", "unit=10", fc2_trainable}},
- {"activation", {"name=act2", "input_layers=fc2", "activation=relu"}},
- {"fully_connected", {"name=fc3", "input_layers=act2", "unit=2"}},
- {"activation", {"name=act3", "input_layers=fc3", "activation=sigmoid"}},
- {"mse", {"name=loss", "input_layers=act3"}},
+ {"name=fc2", "input_layers=fc1", "unit=10", "activation=relu",
+ fc2_trainable}},
+ {"fully_connected",
+ {"name=fc3", "input_layers=fc2", "unit=2", "activation=sigmoid",
+ fc3_trainable}},
+ {"mse", {"name=loss", "input_layers=fc3"}},
});
for (auto &node : outer_graph) {
static auto makeNonTrainableFcIdx1 = getFuncToMakeNonTrainableFc(1);
static auto makeNonTrainableFcIdx2 = getFuncToMakeNonTrainableFc(2);
+static auto makeNonTrainableFcIdx3 = getFuncToMakeNonTrainableFc(3);
static std::unique_ptr<NeuralNetwork> makeMolAttention() {
std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
ModelTestOption::ALL_V2),
mkModelTc_V2(makeNonTrainableFcIdx2, "non_trainable_fc_idx2",
ModelTestOption::ALL_V2),
+ mkModelTc_V2(makeNonTrainableFcIdx3, "non_trainable_fc_idx3",
+ ModelTestOption::ALL_V2),
}),
[](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
return std::get<1>(info.param);