kTfLiteBuiltinLogSoftmax = 50,
kTfLiteBuiltinDelegate = 51,
kTfLiteBuiltinBidirectionalSequenceLstm = 52,
+ kTfLiteBuiltinCast = 53,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
"batch_to_space_nd.cc",
"bidirectional_sequence_lstm.cc",
"bidirectional_sequence_rnn.cc",
+ "cast.cc",
"concatenation.cc",
"conv.cc",
"depthwise_conv.cc",
)
tf_cc_test(
+ name = "cast_test",
+ size = "small",
+ srcs = ["cast_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "concatenation_test",
size = "small",
srcs = ["concatenation_test.cc"],
--- /dev/null
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <algorithm>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace cast {
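+// Kernel for the builtin CAST op: copies the input tensor to the output
+// tensor, converting each element to the output tensor's type.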
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
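+// Checks that the node has exactly one input and one output, and resizes the
+// output to the input's shape. The element type conversion happens in Eval.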
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
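+// Element-wise static_cast of a contiguous buffer from FromT to ToT.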
+template <typename FromT, typename ToT>
+void copyCast(const FromT* in, ToT* out, int num_elements) {
+ std::transform(in, in + num_elements, out,
+ [](FromT a) { return static_cast<ToT>(a); });
+}
+
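+// Dispatches on the output tensor's type and casts the input buffer into it.
+// Returns kTfLiteError for output types the kernel does not support.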
+template <typename FromT>
+TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
+ int num_elements) {
+ switch (out->type) {
+ case kTfLiteInt64:
+ copyCast(in, out->data.i64, num_elements);
+ break;
+ case kTfLiteInt32:
+ copyCast(in, out->data.i32, num_elements);
+ break;
+ case kTfLiteUInt8:
+ copyCast(in, out->data.uint8, num_elements);
+ break;
+ case kTfLiteFloat32:
+ copyCast(in, out->data.f, num_elements);
+ break;
+ default:
+      // Unsupported output type.
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
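+// Dispatches on the input tensor's type; the output type is handled inside
+// copyToTensor.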
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const int num_elements = NumElements(input);
+ TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output));
+ switch (input->type) {
+ case kTfLiteInt64:
+ return copyToTensor(input->data.i64, output, num_elements);
+ case kTfLiteInt32:
+ return copyToTensor(input->data.i32, output, num_elements);
+ case kTfLiteUInt8:
+ return copyToTensor(input->data.uint8, output, num_elements);
+ case kTfLiteFloat32:
+ return copyToTensor(input->data.f, output, num_elements);
+ default:
+      // Unsupported input type.
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+} // namespace cast
+
+TfLiteRegistration* Register_CAST() {
+ static TfLiteRegistration r = {nullptr, nullptr, cast::Prepare, cast::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
--- /dev/null
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
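+// Wraps a single CAST op with one input and one output tensor so tests can
+// populate the input and read back the converted values.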
+class CastOpModel : public SingleOpModel {
+ public:
+ CastOpModel(const TensorData& input, const TensorData& output) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_CAST, BuiltinOptions_CastOptions,
+ CreateCastOptions(builder_).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ int input() const { return input_; }
+ int output() const { return output_; }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+TEST(CastOpModel, CastIntToFloat) {
+ CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
+ m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
+}
+
+TEST(CastOpModel, CastFloatToInt) {
+ CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}});
+ m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int>(m.output()),
+ ElementsAreArray({100, 20, 3, 0, 0, 1}));
+}
+
+} // namespace
+} // namespace tflite
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
TfLiteRegistration* Register_EXP();
TfLiteRegistration* Register_TOPK_V2();
TfLiteRegistration* Register_LOG_SOFTMAX();
+TfLiteRegistration* Register_CAST();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
+ AddBuiltin(BuiltinOperator_CAST, Register_CAST());
}
TfLiteRegistration* BuiltinOpResolver::FindOp(
case BuiltinOperator_EXP:
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_LOG_SOFTMAX:
+ case BuiltinOperator_CAST:
break;
case BuiltinOperator_LSH_PROJECTION: {
TfLiteLSHProjectionParams* params =
case tflite::BuiltinOperator_EXP:
case tflite::BuiltinOperator_LOG_SOFTMAX:
case tflite::BuiltinOperator_DELEGATE:
+ case tflite::BuiltinOperator_CAST:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
// WARNING: Experimental interface, subject to change
DELEGATE = 51,
BIDIRECTIONAL_SEQUENCE_LSTM = 52,
+ CAST = 53,
}
// Options for the builtin operators.
TopKV2Options,
SplitOptions,
LogSoftmaxOptions,
+ CastOptions,
}
enum Padding : byte { SAME, VALID }
table LogSoftmaxOptions {
}
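+// Cast has no parameters; the kernel infers the conversion from the input and
+// output tensor types.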
+table CastOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
struct LogSoftmaxOptions;
struct LogSoftmaxOptionsT;
+struct CastOptions;
+struct CastOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
BuiltinOperator_LOG_SOFTMAX = 50,
BuiltinOperator_DELEGATE = 51,
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52,
+ BuiltinOperator_CAST = 53,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM
+ BuiltinOperator_MAX = BuiltinOperator_CAST
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[50] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[51] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
BuiltinOperator_SPLIT,
BuiltinOperator_LOG_SOFTMAX,
BuiltinOperator_DELEGATE,
- BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM
+ BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
+ BuiltinOperator_CAST
};
return values;
}
"LOG_SOFTMAX",
"DELEGATE",
"BIDIRECTIONAL_SEQUENCE_LSTM",
+ "CAST",
nullptr
};
return names;
BuiltinOptions_TopKV2Options = 34,
BuiltinOptions_SplitOptions = 35,
BuiltinOptions_LogSoftmaxOptions = 36,
+ BuiltinOptions_CastOptions = 37,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_LogSoftmaxOptions
+ BuiltinOptions_MAX = BuiltinOptions_CastOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[37] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[38] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
BuiltinOptions_ExpOptions,
BuiltinOptions_TopKV2Options,
BuiltinOptions_SplitOptions,
- BuiltinOptions_LogSoftmaxOptions
+ BuiltinOptions_LogSoftmaxOptions,
+ BuiltinOptions_CastOptions
};
return values;
}
"TopKV2Options",
"SplitOptions",
"LogSoftmaxOptions",
+ "CastOptions",
nullptr
};
return names;
static const BuiltinOptions enum_value = BuiltinOptions_LogSoftmaxOptions;
};
+template<> struct BuiltinOptionsTraits<CastOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_CastOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
return type == BuiltinOptions_LogSoftmaxOptions ?
reinterpret_cast<const LogSoftmaxOptionsT *>(value) : nullptr;
}
+ CastOptionsT *AsCastOptions() {
+ return type == BuiltinOptions_CastOptions ?
+ reinterpret_cast<CastOptionsT *>(value) : nullptr;
+ }
+ const CastOptionsT *AsCastOptions() const {
+ return type == BuiltinOptions_CastOptions ?
+ reinterpret_cast<const CastOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
flatbuffers::Offset<LogSoftmaxOptions> CreateLogSoftmaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogSoftmaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct CastOptionsT : public flatbuffers::NativeTable {
+ typedef CastOptions TableType;
+ CastOptionsT() {
+ }
+};
+
+struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CastOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ CastOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(CastOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<CastOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct CastOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit CastOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CastOptionsBuilder &operator=(const CastOptionsBuilder &);
+ flatbuffers::Offset<CastOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CastOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CastOptions> CreateCastOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ CastOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<CastOptions> CreateCastOptions(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
const LogSoftmaxOptions *builtin_options_as_LogSoftmaxOptions() const {
return builtin_options_type() == BuiltinOptions_LogSoftmaxOptions ? static_cast<const LogSoftmaxOptions *>(builtin_options()) : nullptr;
}
+ const CastOptions *builtin_options_as_CastOptions() const {
+ return builtin_options_type() == BuiltinOptions_CastOptions ? static_cast<const CastOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
return builtin_options_as_LogSoftmaxOptions();
}
+template<> inline const CastOptions *Operator::builtin_options_as<CastOptions>() const {
+ return builtin_options_as_CastOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
_fbb);
}
+inline CastOptionsT *CastOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new CastOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void CastOptions::UnPackTo(CastOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<CastOptions> CastOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateCastOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<CastOptions> CreateCastOptions(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CastOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateCastOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
auto ptr = reinterpret_cast<const LogSoftmaxOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_CastOptions: {
+ auto ptr = reinterpret_cast<const CastOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
auto ptr = reinterpret_cast<const LogSoftmaxOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_CastOptions: {
+ auto ptr = reinterpret_cast<const CastOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
auto ptr = reinterpret_cast<const LogSoftmaxOptionsT *>(value);
return CreateLogSoftmaxOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_CastOptions: {
+ auto ptr = reinterpret_cast<const CastOptionsT *>(value);
+ return CreateCastOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
value = new LogSoftmaxOptionsT(*reinterpret_cast<LogSoftmaxOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_CastOptions: {
+ value = new CastOptionsT(*reinterpret_cast<CastOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
delete ptr;
break;
}
+ case BuiltinOptions_CastOptions: {
+ auto ptr = reinterpret_cast<CastOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
return CopyBuffer<ArrayDataType::kFloat>(array, builder);
case ArrayDataType::kInt32:
return CopyBuffer<ArrayDataType::kInt32>(array, builder);
+ case ArrayDataType::kInt64:
+ return CopyBuffer<ArrayDataType::kInt64>(array, builder);
case ArrayDataType::kString:
return CopyBuffer<ArrayDataType::kString>(array, builder);
case ArrayDataType::kUint8: