auto* c2 = rhs;
TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
- CreateBinaryHlo(HloOpcode::kAdd, c1, c2));
+ MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
return ReplaceWithNewInstruction(
add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd,
lhs->mutable_operand(0),
// (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
if (lhs->opcode() == HloOpcode::kDivide &&
rhs->opcode() == HloOpcode::kDivide) {
- TF_ASSIGN_OR_RETURN(
- auto a_times_d,
- CreateBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(0),
- rhs->mutable_operand(1)));
- TF_ASSIGN_OR_RETURN(
- auto b_times_c,
- CreateBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1),
- rhs->mutable_operand(0)));
- TF_ASSIGN_OR_RETURN(auto new_divide, CreateBinaryHlo(HloOpcode::kDivide,
- a_times_d, b_times_c));
+ TF_ASSIGN_OR_RETURN(auto a_times_d, MakeBinaryHlo(HloOpcode::kMultiply,
+ lhs->mutable_operand(0),
+ rhs->mutable_operand(1)));
+ TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply,
+ lhs->mutable_operand(1),
+ rhs->mutable_operand(0)));
+ TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
+ a_times_d, b_times_c));
return ReplaceInstruction(divide, new_divide);
}
if (lhs->opcode() == HloOpcode::kDivide) {
TF_ASSIGN_OR_RETURN(
auto b_times_c,
- CreateBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs));
+ MakeBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs));
return ReplaceWithNewInstruction(
divide,
HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide,
// A / (B / C) => (A*C) / B
if (rhs->opcode() == HloOpcode::kDivide) {
- TF_ASSIGN_OR_RETURN(
- auto a_times_c,
- CreateBinaryHlo(HloOpcode::kMultiply, lhs, rhs->mutable_operand(1)));
+ TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, lhs,
+ rhs->mutable_operand(1)));
return ReplaceWithNewInstruction(
divide,
HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide,
}
TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad,
- CreatePadHlo(pad->mutable_operand(0),
- pad->mutable_operand(1), nonzero_padding));
+ MakePadHlo(pad->mutable_operand(0),
+ pad->mutable_operand(1), nonzero_padding));
// Copy the layout from the original pad instructions. The new pad and the
// slice instruction should all have the same layout.
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
TF_ASSIGN_OR_RETURN(
HloInstruction * slice,
- CreateSliceHlo(nonzero_pad, start_indices, end_indices, strides));
+ MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
// Verify that the slice shape matches the pad shape.
TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape()));
}
}
permutation.push_back(index_vector_dim);
- return CreateTransposeHlo(gather_indices, permutation);
+ return MakeTransposeHlo(gather_indices, permutation);
}
// If the gather_indices holds scalar indices (i.e. gather_indices has rank N
dim_numbers.gather_dims_to_operand_dims_size()) {
TF_ASSIGN_OR_RETURN(
HloInstruction * component_to_concat,
- CreateSliceHlo(
- index_vector, /*start_indices=*/{index_vector_dim_index},
- /*limit_indices=*/{index_vector_dim_index + 1}, /*strides=*/{1}));
+ MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
+ /*limit_indices=*/{index_vector_dim_index + 1},
+ /*strides=*/{1}));
expanded_index_components.push_back(component_to_concat);
} else {
expanded_index_components.push_back(zero);
}
}
- return CreateConcatHlo(expanded_index_components, /*dimension=*/0);
+ return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
}
// This generates the body of the while that implements the main data movement
TF_ASSIGN_OR_RETURN(
HloInstruction * induction_var_as_vector,
- CreateBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
- /*result_shape_bounds=*/{1}));
+ MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
+ /*result_shape_bounds=*/{1}));
TF_ASSIGN_OR_RETURN(
HloInstruction * index_into_gather_indices,
TF_ASSIGN_OR_RETURN(
HloInstruction * index_vector_2d,
- CreateDynamicSliceHlo(gather_indices, index_into_gather_indices,
- {1, index_vector_size}));
+ MakeDynamicSliceHlo(gather_indices, index_into_gather_indices,
+ {1, index_vector_size}));
TF_ASSIGN_OR_RETURN(HloInstruction * index_vector,
ElideDegenerateDims(index_vector_2d, {0}));
operand->shape().dimensions_size()));
TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice,
- CreateDynamicSliceHlo(operand, gathered_slice_start,
- gather.gather_window_bounds()));
+ MakeDynamicSliceHlo(operand, gathered_slice_start,
+ gather.gather_window_bounds()));
TF_ASSIGN_OR_RETURN(
HloInstruction * gathered_slice_for_update,
TF_ASSIGN_OR_RETURN(
HloInstruction * updated_accumulator,
- CreateDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update,
- index_vector_into_accumulator));
+ MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update,
+ index_vector_into_accumulator));
// New loop state -- only the accumulator has changed. The
// WhileUtil::MakeCountedLoop functions takes care of the induction variable
}
}
- return CreateTransposeHlo(accumulator, permutation);
+ return MakeTransposeHlo(accumulator, permutation);
}
// High Level Algorithm
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
MakeUnique<Literal>(Literal::Zero(element_type))));
- input = CreatePadHlo(input, padding, padding_config).ValueOrDie();
+ input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
if (window_util::HasNegativePadding(conv_window)) {
std::max<int64>(0LL, -conv_window.dimensions(i).padding_high());
}
- input = CreateSliceHlo(input, start_indices, limit_indices, strides)
- .ValueOrDie();
+ input =
+ MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie();
}
return input;
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
MakeUnique<Literal>(Literal::Zero(element_type))));
- return CreatePadHlo(kernel, padding, padding_config).ValueOrDie();
+ return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
computation->AddInstruction(HloInstruction::CreateConstant(
MakeUnique<Literal>(Literal::Zero(input->shape().element_type()))));
HloInstruction* padded_input =
- CreatePadHlo(input, padding, input_padding_config).ValueOrDie();
+ MakePadHlo(input, padding, input_padding_config).ValueOrDie();
// The shape of the backward_conv CustomCall is a tuple (conv_result,
// scratch_buffer). Extract out the shape of conv_result.
using tensorflow::gtl::ArraySlice;
using tensorflow::strings::StrCat;
-StatusOr<HloInstruction*> CreateBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
- HloInstruction* rhs) {
+StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
+ HloInstruction* rhs) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(Shape binary_op_shape,
HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
}
-StatusOr<HloInstruction*> CreatePadHlo(HloInstruction* operand,
- HloInstruction* padding_value,
- const PaddingConfig& padding_config) {
+StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
+ HloInstruction* padding_value,
+ const PaddingConfig& padding_config) {
HloComputation* computation = operand->parent();
CHECK_EQ(computation, padding_value->parent());
TF_ASSIGN_OR_RETURN(
pad_shape, operand, padding_value, padding_config));
}
-StatusOr<HloInstruction*> CreateSliceHlo(HloInstruction* operand,
- ArraySlice<int64> start_indices,
- ArraySlice<int64> limit_indices,
- ArraySlice<int64> strides) {
+StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
+ ArraySlice<int64> start_indices,
+ ArraySlice<int64> limit_indices,
+ ArraySlice<int64> strides) {
HloComputation* computation = operand->parent();
TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
operand->shape(), start_indices,
slice_shape, operand, start_indices, limit_indices, strides));
}
-StatusOr<HloInstruction*> CreateConvolveHlo(
+StatusOr<HloInstruction*> MakeConvolveHlo(
HloInstruction* lhs, HloInstruction* rhs, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers) {
HloComputation* computation = lhs->parent();
convolve_shape, lhs, rhs, window, dimension_numbers));
}
-StatusOr<HloInstruction*> CreateTransposeHlo(HloInstruction* operand,
- ArraySlice<int64> dimensions) {
+StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
+ ArraySlice<int64> dimensions) {
HloComputation* computation = operand->parent();
TF_ASSIGN_OR_RETURN(
Shape transpose_shape,
HloInstruction::CreateTranspose(transpose_shape, operand, dimensions));
}
-StatusOr<HloInstruction*> CreateReshapeHlo(const Shape& result_shape,
- HloInstruction* operand) {
+StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
+ HloInstruction* operand) {
HloComputation* computation = operand->parent();
return computation->AddInstruction(
HloInstruction::CreateReshape(result_shape, operand));
}
-StatusOr<HloInstruction*> CreateReshapeHlo(
+StatusOr<HloInstruction*> MakeReshapeHlo(
ArraySlice<int64> result_shape_dim_bounds, HloInstruction* operand) {
Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
result_shape_dim_bounds);
- return CreateReshapeHlo(new_shape, operand);
+ return MakeReshapeHlo(new_shape, operand);
}
-StatusOr<HloInstruction*> CreateDynamicSliceHlo(HloInstruction* operand,
- HloInstruction* start_indices,
- ArraySlice<int64> slice_sizes) {
+StatusOr<HloInstruction*> MakeDynamicSliceHlo(HloInstruction* operand,
+ HloInstruction* start_indices,
+ ArraySlice<int64> slice_sizes) {
HloComputation* computation = operand->parent();
CHECK_EQ(computation, start_indices->parent());
TF_ASSIGN_OR_RETURN(
dynamic_slice_shape, operand, start_indices, slice_sizes));
}
-StatusOr<HloInstruction*> CreateDynamicUpdateSliceHlo(
+StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
HloInstruction* operand, HloInstruction* update,
HloInstruction* start_indices) {
HloComputation* computation = operand->parent();
dynamic_update_slice_shape, operand, update, start_indices));
}
-StatusOr<HloInstruction*> CreateBroadcastHlo(
+StatusOr<HloInstruction*> MakeBroadcastHlo(
HloInstruction* operand, ArraySlice<int64> broadcast_dimensions,
ArraySlice<int64> result_shape_bounds) {
HloComputation* computation = operand->parent();
broadcast_shape, operand, broadcast_dimensions));
}
-StatusOr<HloInstruction*> CreateGetTupleElementHlo(HloInstruction* operand,
- int64 index) {
+StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
+ int64 index) {
HloComputation* computation = operand->parent();
TF_ASSIGN_OR_RETURN(
HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
}
-StatusOr<HloInstruction*> CreateConcatHlo(ArraySlice<HloInstruction*> operands,
- int64 dimension) {
+StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands,
+ int64 dimension) {
CHECK_GT(operands.size(), 0);
HloComputation* computation = operands[0]->parent();
Shape output_shape =
ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);
- return CreateReshapeHlo(output_shape, operand);
+ return MakeReshapeHlo(output_shape, operand);
}
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
std::back_inserter(expanded_shape_dim_bounds));
Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
expanded_shape_dim_bounds);
- return CreateReshapeHlo(new_shape, operand);
+ return MakeReshapeHlo(new_shape, operand);
}
StatusOr<HloInstruction*> ExpandLastDimIntoNDims(
c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
expanded_shape_dim_bounds);
- return CreateReshapeHlo(new_shape, operand);
+ return MakeReshapeHlo(new_shape, operand);
}
StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
c_reverse(new_shape_dim_bounds);
Shape output_shape =
ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds);
- return CreateReshapeHlo(output_shape, operand);
+ return MakeReshapeHlo(output_shape, operand);
}
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
MakeUnique<Literal>(Literal::Zero(operand->shape().element_type()))));
- return CreatePadHlo(operand, zero, padding_config);
+ return MakePadHlo(operand, zero, padding_config);
}
StatusOr<HloInstruction*> BroadcastZeros(
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
MakeUnique<Literal>(Literal::Zero(element_type))));
- return CreateBroadcastHlo(zero, /*broadcast_dimensions=*/{},
- /*result_shape_bounds=*/broadcast_dimensions);
+ return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
+ /*result_shape_bounds=*/broadcast_dimensions);
}
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
// Creates a binary HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
-StatusOr<HloInstruction*> CreateBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
- HloInstruction* rhs);
+StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
+ HloInstruction* rhs);
// Creates a pad HLO instruction and adds it to the computation containing
// `operand` and `padding_value` (`operand` and `padding_value` must be in the
// same computation).
-StatusOr<HloInstruction*> CreatePadHlo(HloInstruction* operand,
- HloInstruction* padding_value,
- const PaddingConfig& padding_config);
+StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
+ HloInstruction* padding_value,
+ const PaddingConfig& padding_config);
// Creates a slice HLO instruction and adds it to the computation containing
// `operand`.
-StatusOr<HloInstruction*> CreateSliceHlo(
+StatusOr<HloInstruction*> MakeSliceHlo(
HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides);
// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
-StatusOr<HloInstruction*> CreateConvolveHlo(
+StatusOr<HloInstruction*> MakeConvolveHlo(
HloInstruction* lhs, HloInstruction* rhs, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers);
// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
-StatusOr<HloInstruction*> CreateTransposeHlo(
+StatusOr<HloInstruction*> MakeTransposeHlo(
HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions);
// Creates a reshape HLO instruction and adds it to the computation containing
// `operand`.
-StatusOr<HloInstruction*> CreateReshapeHlo(const Shape& result_shape,
- HloInstruction* operand);
+StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
+ HloInstruction* operand);
-StatusOr<HloInstruction*> CreateReshapeHlo(
+StatusOr<HloInstruction*> MakeReshapeHlo(
tensorflow::gtl::ArraySlice<int64> result_shape_dim_bounds,
HloInstruction* operand);
// Creates a dynamic-slice HLO instruction and adds it to the computation
// containing `operand` and `start_indices` (`operand` and `start_indices` must
// be in the same computation).
-StatusOr<HloInstruction*> CreateDynamicSliceHlo(
+StatusOr<HloInstruction*> MakeDynamicSliceHlo(
HloInstruction* operand, HloInstruction* start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Creates a dynamic-update-slice HLO instruction and adds it to the computation
// containing `operand`, `update` and `start_indices` (`operand`, `update` and
// `start_indices` must be in the same computation).
-StatusOr<HloInstruction*> CreateDynamicUpdateSliceHlo(
+StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
HloInstruction* operand, HloInstruction* update,
HloInstruction* start_indices);
// Creates a broadcast HLO instruction and adds it to the computation containing
// `operand`.
-StatusOr<HloInstruction*> CreateBroadcastHlo(
+StatusOr<HloInstruction*> MakeBroadcastHlo(
HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions,
tensorflow::gtl::ArraySlice<int64> result_shape_bounds);
// Creates a GetTupleElement HLO instruction and adds it to the computation
// containing `operand`.
-StatusOr<HloInstruction*> CreateGetTupleElementHlo(HloInstruction* operand,
- int64 index);
+StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
+ int64 index);
// Creates a Concatenate HLO instruction and adds it to the computation
// containing `operands` (`operands` must be non-empty and every element must be
// contained in the same computation).
-StatusOr<HloInstruction*> CreateConcatHlo(
+StatusOr<HloInstruction*> MakeConcatHlo(
tensorflow::gtl::ArraySlice<HloInstruction*> operands, int64 dimension);
// -----------------------------------------------------------------------------
HloInstruction* param = cond_computation->parameter_instruction(0);
TF_ASSIGN_OR_RETURN(HloInstruction * counter,
- CreateGetTupleElementHlo(param, 0));
+ MakeGetTupleElementHlo(param, 0));
TF_ASSIGN_OR_RETURN(
HloInstruction * compare,
- CreateBinaryHlo(HloOpcode::kLt, counter, trip_count_constant));
+ MakeBinaryHlo(HloOpcode::kLt, counter, trip_count_constant));
cond_computation->set_root_instruction(compare);
return std::move(cond_computation);
}
HloInstruction* param = body_computation->parameter_instruction(0);
TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
- CreateGetTupleElementHlo(param, 0));
+ MakeGetTupleElementHlo(param, 0));
TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar,
- CreateBinaryHlo(HloOpcode::kAdd, indvar, one));
+ MakeBinaryHlo(HloOpcode::kAdd, indvar, one));
std::vector<HloInstruction*> loop_body_generator_args;
for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) {
TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element,
- CreateGetTupleElementHlo(param, i));
+ MakeGetTupleElementHlo(param, i));
loop_body_generator_args.push_back(tuple_element);
}
TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> next_state,
std::vector<HloInstruction*> result;
for (int64 i = 0, e = init_values.size(); i < e; i++) {
TF_ASSIGN_OR_RETURN(HloInstruction * user_state,
- CreateGetTupleElementHlo(while_instr, i + 1));
+ MakeGetTupleElementHlo(while_instr, i + 1));
result.push_back(user_state);
}
return result;