constexpr int kBwProjectionBiasTensor = 34; // Optional
// Output tensors.
-constexpr int kFwScratchBufferTensor = 0;
-constexpr int kFwOutputStateTensor = 1;
-constexpr int kFwCellStateTensor = 2;
-constexpr int kFwOutputTensor = 3;
+constexpr int kFwOutputStateTensor = 0;
+constexpr int kFwCellStateTensor = 1;
+constexpr int kFwOutputTensor = 2;
+
+constexpr int kBwOutputStateTensor = 3;
+constexpr int kBwCellStateTensor = 4;
+constexpr int kBwOutputTensor = 5;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 2, scratch_tensor_index);
+ return scratch_tensor_index;
+}
-constexpr int kBwScratchBufferTensor = 4;
-constexpr int kBwOutputStateTensor = 5;
-constexpr int kBwCellStateTensor = 6;
-constexpr int kBwOutputTensor = 7;
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckLstmTensorDimensions(
// Resize the output, state and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 35);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 8);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 6);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
TfLiteTensor* fw_output_state =
GetOutput(context, node, kFwOutputStateTensor);
TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* fw_scratch_buffer =
- GetOutput(context, node, kFwScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output_state,
fw_output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* fw_cell_size = TfLiteIntArrayCreate(2);
fw_cell_size->data[0] = n_batch;
fw_cell_size->data[1] = n_fw_cell;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_cell_state, fw_cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* fw_scratch_buffer =
+ &context->tensors[node->temporaries->data[0]];
+ fw_scratch_buffer->type = input->type;
+ fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
fw_output_state->allocation_type = kTfLiteArenaRwPersistent;
fw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
TfLiteTensor* bw_output_state =
GetOutput(context, node, kBwOutputStateTensor);
TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* bw_scratch_buffer =
- GetOutput(context, node, kBwScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
bw_output_size->data[0] = max_time;
bw_output_size->data[1] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output_state,
bw_output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* bw_cell_size = TfLiteIntArrayCreate(2);
bw_cell_size->data[0] = n_batch;
bw_cell_size->data[1] = n_bw_cell;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, bw_cell_state, bw_cell_size));
+ // Create a scratch buffer tensor.
+ node->temporaries->data[1] = *(scratch_tensor_index) + 1;
+ TfLiteTensor* bw_scratch_buffer =
+ &context->tensors[node->temporaries->data[1]];
+ bw_scratch_buffer->type = input->type;
+ bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
bw_output_state->allocation_type = kTfLiteArenaRwPersistent;
bw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* fw_scratch_buffer =
- GetOutput(context, node, kFwScratchBufferTensor);
+ &context->tensors[node->temporaries->data[0]];
float* fw_input_gate_scratch = nullptr;
float* fw_cell_scratch = nullptr;
float* fw_forget_gate_scratch = nullptr;
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* bw_scratch_buffer =
- GetOutput(context, node, kBwScratchBufferTensor);
+ &context->tensors[node->temporaries->data[1]];
float* bw_input_gate_scratch = nullptr;
float* bw_cell_scratch = nullptr;
float* bw_forget_gate_scratch = nullptr;
} // namespace bidirectional_sequence_lstm
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- bidirectional_sequence_lstm::Prepare,
- bidirectional_sequence_lstm::Eval};
+ static TfLiteRegistration r = {
+ bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
+ bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
return &r;
}
fw_projection_bias_ = AddNullInput();
}
- fw_scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
fw_output_state_ = AddOutput(TensorType_FLOAT32);
fw_cell_state_ = AddOutput(TensorType_FLOAT32);
fw_output_ = AddOutput(TensorType_FLOAT32);
bw_projection_bias_ = AddNullInput();
}
- bw_scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
bw_output_state_ = AddOutput(TensorType_FLOAT32);
bw_cell_state_ = AddOutput(TensorType_FLOAT32);
bw_output_ = AddOutput(TensorType_FLOAT32);
int fw_output_;
int fw_output_state_;
int fw_cell_state_;
- int fw_scratch_buffer_;
int bw_output_;
int bw_output_state_;
int bw_cell_state_;
- int bw_scratch_buffer_;
int n_batch_;
int n_input_;
constexpr int kProjectionBiasTensor = 17; // Optional
// Output tensors.
-constexpr int kScratchBufferTensor = 0;
-constexpr int kOutputStateTensor = 1;
-constexpr int kCellStateTensor = 2;
-constexpr int kOutputTensor = 3;
+constexpr int kOutputStateTensor = 0;
+constexpr int kCellStateTensor = 1;
+constexpr int kOutputTensor = 2;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 1, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
-// tensors. Also check that the size of the input tensors match each other.
+// Resize the output, state tensors based on the sizes of the input tensors.
+// Allocate a temporary scratch tensor. Also check that the sizes of the input
+// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
output_size->data[0] = n_batch;
output_size->data[1] = n_output;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, output_state, output_state_size));
- // Resize the output, state and scratch buffer tensors.
TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
cell_size->data[0] = n_batch;
cell_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, cell_state, cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
output_state->allocation_type = kTfLiteArenaRwPersistent;
cell_state->allocation_type = kTfLiteArenaRwPersistent;
const bool use_peephole = (cell_to_output_weights != nullptr);
// Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
} // namespace lstm
TfLiteRegistration* Register_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- lstm::Prepare, lstm::Eval};
+ static TfLiteRegistration r = {lstm::Init, lstm::Free, lstm::Prepare,
+ lstm::Eval};
return &r;
}
projection_bias_ = AddNullInput();
}
- scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
int output_;
int output_state_;
int cell_state_;
- int scratch_buffer_;
int n_batch_;
int n_input_;
projection_bias_ = AddNullInput();
}
- scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
int output_;
int output_state_;
int cell_state_;
- int scratch_buffer_;
int n_batch_;
int n_input_;
constexpr int kProjectionBiasTensor = 17; // Optional
// Output tensors.
-constexpr int kScratchBufferTensor = 0;
-constexpr int kOutputStateTensor = 1;
-constexpr int kCellStateTensor = 2;
-constexpr int kOutputTensor = 3;
+constexpr int kOutputStateTensor = 0;
+constexpr int kCellStateTensor = 1;
+constexpr int kOutputTensor = 2;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 1, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
-// tensors. Also check that the size of the input tensors match each other.
+// Resize the output and state tensors based on the sizes of the input tensors.
+// Allocate a temporary scratch tensor. Also check that the sizes of the input
+// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = max_time;
output_size->data[1] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, output_state, output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
cell_size->data[0] = n_batch;
cell_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, cell_state, cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
output_state->allocation_type = kTfLiteArenaRwPersistent;
cell_state->allocation_type = kTfLiteArenaRwPersistent;
const bool use_peephole = (cell_to_output_weights != nullptr);
// Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
} // namespace unidirectional_sequence_lstm
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
+ unidirectional_sequence_lstm::Free,
unidirectional_sequence_lstm::Prepare,
unidirectional_sequence_lstm::Eval};
return &r;
projection_bias_ = AddNullInput();
}
- scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
int output_;
int output_state_;
int cell_state_;
- int scratch_buffer_;
int n_batch_;
int n_input_;
ASSERT_TRUE(ConvertCsvData(
"speech_speakerid_model.tflite", "speech_speakerid_model_in.csv",
"speech_speakerid_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"66",
- /*persistent_tensors=*/"19,20,40,41,61,62",
+ /*output_tensor=*/"63",
+ /*persistent_tensors=*/"18,19,38,39,58,59",
/*sequence_size=*/80, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
ASSERT_TRUE(
ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv",
"speech_asr_am_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"109",
- /*persistent_tensors=*/"19,20,40,41,61,62,82,83,103,104",
+ /*output_tensor=*/"104",
+ /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
/*sequence_size=*/320, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
ASSERT_TRUE(ConvertCsvData(
"speech_endpointer_model.tflite", "speech_endpointer_model_in.csv",
"speech_endpointer_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"58",
- /*persistent_tensors=*/"28,29,49,50",
+ /*output_tensor=*/"56",
+ /*persistent_tensors=*/"27,28,47,48",
/*sequence_size=*/320, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite",
"speech_tts_model_in.csv",
"speech_tts_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"74",
- /*persistent_tensors=*/"25,26,46,47,67,68,73",
+ /*output_tensor=*/"71",
+ /*persistent_tensors=*/"24,25,44,45,64,65,70",
/*sequence_size=*/334, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
load_model: "speech_asr_lm_model.tflite"
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 3
input: "63982"
input: "63981"
output: "-0.314846"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 6
input: "63982"
input: "3082"
output: "-3.63721"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 8
input: "63982"
input: "18965"
output: "-6.93985"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 13
input: "63982"
input: "63981"
output: "-3.82091"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 19
input: "63982"
input: "63981"
output: "-0.677399"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 26
input: "63982"
input: "63981"
output: "0.415889"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 30
input: "63982"
input: "51923"
output: "-14.1147"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 34
input: "63982"
input: "16318"
output: "-1.54815"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 36
input: "63982"
input: "28303"
output: "-14.0947"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 38
input: "63982"
lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = prev_activ_input;
lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state_input;
- // Reorder LstmCell's 4 outputs.
+  // Reorder the three LstmCell outputs taken from the source op; the fourth
+  // (CONCAT_TEMP) is created as a fresh temp array below.
lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] =
src_op->outputs[kOutputTensor];
lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] =
src_op->outputs[kCellStateTensor];
- lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] =
- src_op->outputs[kScratchBufferTensor];
lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] =
src_op->outputs[kOutputStateTensor];
+ // Create a new temp array for the fourth output.
+ const string& concat_temp_array_name =
+ AvailableArrayName(*model, base_name + "concat_temp");
+ model->GetOrCreateArray(concat_temp_array_name);
+ lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
// Add the op into model.
model->operators.emplace(op_it, std::move(lstm_cell_op));
CreateOptionalArray(model, &(lstm_cell_op->inputs[kProjectionBiasTensor]),
base_name + "proj_bias");
- // Reorder LstmCell's outputs.
- lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
- lstm_cell_op->outputs[kScratchBufferTensor] =
- curr_op->outputs[LstmCellOperator::CONCAT_TEMP];
+ // Reorder and resize LstmCell's outputs.
+ lstm_cell_op->outputs.resize(
+ ExtendedLstmCellOutputs::kExtendedLstmOutputCount);
lstm_cell_op->outputs[kOutputStateTensor] =
curr_op->outputs[LstmCellOperator::ACTIV_TEMP];
lstm_cell_op->outputs[kCellStateTensor] =
};
enum ExtendedLstmCellOutputs {
- kScratchBufferTensor = 0,
- kOutputStateTensor = 1,
- kCellStateTensor = 2,
- kOutputTensor = 3
+ kOutputStateTensor = 0,
+ kCellStateTensor = 1,
+ kOutputTensor = 2,
+ kExtendedLstmOutputCount = 3
};
// Create optional array used for optional tensor in ExtendedLstmCell inputs.