}
}
- bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
+ bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
+ gtl::ArraySlice<tensorflow::DataType> dtypes);
void Watch(int64 tensor_id);
void RecordOperation(const string& op_type,
gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
+ gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter);
// Template instantiations here
+// Returns true iff tensors of `dtype` can carry gradients. Only floating
+// point, complex, resource and variant dtypes participate in
+// differentiation; everything else (ints, bools, strings, ...) is treated
+// as non-trainable and is not recorded on the tape.
+inline bool IsDtypeTrainable(DataType dtype) {
+  switch (dtype) {
+    default:
+      return false;
+    case DT_HALF:
+    case DT_BFLOAT16:
+    case DT_FLOAT:
+    case DT_DOUBLE:
+    case DT_COMPLEX64:
+    case DT_COMPLEX128:
+    case DT_RESOURCE:
+    case DT_VARIANT:
+      return true;
+  }
+}
+
template <typename Gradient, typename BackwardFunction>
bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
- gtl::ArraySlice<int64> tensor_ids) {
- for (int64 i : tensor_ids) {
- if (tensor_tape_.find(i) != tensor_tape_.end()) {
- return true;
+ gtl::ArraySlice<int64> tensor_ids,
+ gtl::ArraySlice<tensorflow::DataType> dtypes) {
+ CHECK_EQ(tensor_ids.size(), dtypes.size());
+ for (int i = 0; i < tensor_ids.size(); ++i) {
+ if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
+ return IsDtypeTrainable(dtypes[i]);
}
}
return false;
template <typename Gradient, typename BackwardFunction>
void GradientTape<Gradient, BackwardFunction>::RecordOperation(
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
- gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
+ gtl::ArraySlice<int64> input_tensor_id,
+ gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+ BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter) {
- if (!ShouldRecord(input_tensor_id)) {
+ if (!ShouldRecord(input_tensor_id, input_dtypes)) {
backward_function_deleter();
return;
}
return id;
}
+static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
+ if (EagerTensor_CheckExact(tensor)) {
+ return EagerTensor_dtype(tensor);
+ }
+ PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
+ if (dtype_field == nullptr) {
+ return tensorflow::DT_INVALID;
+ }
+ PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum");
+ Py_DECREF(dtype_field);
+ if (dtype_field == nullptr) {
+ return tensorflow::DT_INVALID;
+ }
+ tensorflow::int64 id = MakeInt(enum_field);
+ Py_DECREF(enum_field);
+ return static_cast<tensorflow::DataType>(id);
+}
+
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyObject> {
public:
// TODO(apassos) consider not building a list and changing the API to check
// each tensor individually.
std::vector<tensorflow::int64> tensor_ids;
+ std::vector<tensorflow::DataType> dtypes;
tensor_ids.reserve(len);
+ dtypes.reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
tensor_ids.push_back(FastTensorId(item));
+ dtypes.push_back(FastTensorDtype(item));
}
Py_DECREF(seq);
auto tape_set = *tape_set_ptr;
for (TFE_Py_Tape* tape : tape_set) {
- if (tape->tape->ShouldRecord(tensor_ids)) {
+ if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
Py_RETURN_TRUE;
}
}
}
namespace {
-void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
- const std::vector<tensorflow::int64>& input_ids,
- PyObject* backward_function) {
+// Builds a vector holding the DataType of every tensor in `tensors` (any
+// Python sequence), in order. If `tensors` cannot be viewed as a sequence,
+// returns an empty vector and leaves the Python error set for the caller
+// to observe via PyErr_Occurred().
+std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
+  std::vector<tensorflow::DataType> dtypes;
+  PyObject* fast_seq = PySequence_Fast(tensors, "expected a sequence");
+  if (fast_seq == nullptr) {
+    return dtypes;
+  }
+  const int num_tensors = PySequence_Fast_GET_SIZE(fast_seq);
+  dtypes.reserve(num_tensors);
+  for (int idx = 0; idx < num_tensors; ++idx) {
+    dtypes.push_back(
+        FastTensorDtype(PySequence_Fast_GET_ITEM(fast_seq, idx)));
+  }
+  Py_DECREF(fast_seq);
+  return dtypes;
+}
+
+void TapeSetRecordOperation(
+ PyObject* op_type, PyObject* output_tensors,
+ const std::vector<tensorflow::int64>& input_ids,
+ const std::vector<tensorflow::DataType>& input_dtypes,
+ PyObject* backward_function) {
std::vector<tensorflow::eager::TapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
for (TFE_Py_Tape* tape : SafeTapeSet()) {
Py_INCREF(backward_function);
tape->tape->RecordOperation(
- op_type_str, output_info, input_ids, backward_function,
+ op_type_str, output_info, input_ids, input_dtypes, backward_function,
[backward_function]() { Py_DECREF(backward_function); });
}
}
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
if (PyErr_Occurred()) return;
- TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function);
+ std::vector<tensorflow::DataType> input_dtypes =
+ MakeTensorDtypeList(input_tensors);
+ if (PyErr_Occurred()) return;
+ TapeSetRecordOperation(op_type, output_tensors, input_ids, input_dtypes,
+ backward_function);
}
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
PyObject* results, PyObject* name) {
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
if (PyErr_Occurred()) return nullptr;
+ std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
+ if (PyErr_Occurred()) return nullptr;
bool should_record = false;
for (TFE_Py_Tape* tape : SafeTapeSet()) {
- if (tape->tape->ShouldRecord(input_ids)) {
+ if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
should_record = true;
break;
}
Py_DECREF(callback_args);
if (backward_function == nullptr) return nullptr;
- TapeSetRecordOperation(op_name, results, input_ids, backward_function);
+ TapeSetRecordOperation(op_name, results, input_ids, input_dtypes,
+ backward_function);
Py_DECREF(backward_function);