};
// Construct the While loop, extract and reshape the output.
- auto num_indices_value =
- XlaHelpers::IntegerLiteral(builder, index_type, num_indices);
- TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices_value, body_fn,
+ xla::PrimitiveType ptype;
+ TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype));
+ TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn,
init, "gather", builder));
*gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes());
return Status::OK();
return std::vector<xla::ComputationDataHandle>{indices, updates, buffer};
};
- xla::ComputationDataHandle num_indices_value =
- IntegerLiteral(builder, indices_shape->element_type(), num_indices);
- TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices_value, body_fn,
- init, "scatter", builder));
+ TF_ASSIGN_OR_RETURN(
+ auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(),
+ body_fn, init, "scatter", builder));
return outputs[2];
}
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
}
xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
- const xla::ComputationDataHandle& num_iterations,
+ int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
StringPiece name, xla::ComputationBuilder* builder) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> num_iterations_shape,
- builder->GetShape(num_iterations));
- TF_RET_CHECK(xla::ShapeUtil::IsScalar(*num_iterations_shape));
- xla::PrimitiveType num_iterations_type = num_iterations_shape->element_type();
-
auto while_cond_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
xla::ComputationBuilder* cond_builder)
-> xla::StatusOr<xla::ComputationDataHandle> {
- return cond_builder->Lt(values[0], values[1]);
+ return cond_builder->Lt(
+ values[0],
+ IntegerLiteral(cond_builder, num_iterations_type, num_iterations));
};
auto while_body_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
xla::ComputationBuilder* body_builder)
updated_values.push_back(body_builder->Add(
iteration,
body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type))));
- updated_values.push_back(values[1]);
- values.remove_prefix(2);
+ values.remove_prefix(1);
TF_ASSIGN_OR_RETURN(std::vector<xla::ComputationDataHandle> body_outputs,
body_function(iteration, values, body_builder));
updated_values.insert(updated_values.end(), body_outputs.begin(),
};
std::vector<xla::ComputationDataHandle> values;
- values.reserve(initial_values.size() + 2);
+ values.reserve(initial_values.size() + 1);
values.push_back(
builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type)));
- values.push_back(num_iterations);
values.insert(values.end(), initial_values.begin(), initial_values.end());
TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
name, builder));
- values.erase(values.begin(), values.begin() + 2);
+ values.erase(values.begin(), values.begin() + 1);
return values;
}