Rename CreateXyzHlo utilities to MakeXyzHlo as discussed on cr/188968478; NFC
author     Sanjoy Das <sanjoy@google.com>
           Thu, 15 Mar 2018 22:47:09 +0000 (15:47 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Thu, 15 Mar 2018 23:01:07 +0000 (16:01 -0700)
The rationale here is that MakeXyzHlo is less likely to be confused with
HloInstruction::CreateXyz, and we already have a convention of using a "Make"
prefix for ergonomic factory functions.
PiperOrigin-RevId: 189259036
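
To illustrate the convention, a minimal sketch of the two styles (a hedged
example based on the signatures in this patch, not a call site taken from it;
assumes lhs and rhs are same-shaped HloInstruction*s in computation):

    // Low-level factory: the caller computes the result shape and adds the
    // instruction to the computation by hand.
    HloInstruction* sum = computation->AddInstruction(
        HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));

    // Ergonomic wrapper (renamed by this patch): the result shape is inferred
    // via ShapeInference, the instruction is added to the operands'
    // computation, and failures surface through StatusOr.
    TF_ASSIGN_OR_RETURN(HloInstruction * sum_via_make,
                        MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs));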

tensorflow/compiler/xla/service/algebraic_simplifier.cc
tensorflow/compiler/xla/service/gather_expander.cc
tensorflow/compiler/xla/service/gpu/pad_insertion.cc
tensorflow/compiler/xla/service/hlo_creation_utils.cc
tensorflow/compiler/xla/service/hlo_creation_utils.h
tensorflow/compiler/xla/service/while_util.cc

diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index be7aa30..971c293 100644
@@ -385,7 +385,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
     auto* c2 = rhs;
 
     TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
-                        CreateBinaryHlo(HloOpcode::kAdd, c1, c2));
+                        MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
     return ReplaceWithNewInstruction(
         add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd,
                                           lhs->mutable_operand(0),
@@ -636,16 +636,14 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
   // (A / B) / (C / D)  =>  (A / B)*(D / C) => (A * D) / (B * C)
   if (lhs->opcode() == HloOpcode::kDivide &&
       rhs->opcode() == HloOpcode::kDivide) {
-    TF_ASSIGN_OR_RETURN(
-        auto a_times_d,
-        CreateBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(0),
-                        rhs->mutable_operand(1)));
-    TF_ASSIGN_OR_RETURN(
-        auto b_times_c,
-        CreateBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1),
-                        rhs->mutable_operand(0)));
-    TF_ASSIGN_OR_RETURN(auto new_divide, CreateBinaryHlo(HloOpcode::kDivide,
-                                                         a_times_d, b_times_c));
+    TF_ASSIGN_OR_RETURN(auto a_times_d, MakeBinaryHlo(HloOpcode::kMultiply,
+                                                      lhs->mutable_operand(0),
+                                                      rhs->mutable_operand(1)));
+    TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply,
+                                                      lhs->mutable_operand(1),
+                                                      rhs->mutable_operand(0)));
+    TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
+                                                       a_times_d, b_times_c));
 
     return ReplaceInstruction(divide, new_divide);
   }
@@ -654,7 +652,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
   if (lhs->opcode() == HloOpcode::kDivide) {
     TF_ASSIGN_OR_RETURN(
         auto b_times_c,
-        CreateBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs));
+        MakeBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs));
     return ReplaceWithNewInstruction(
         divide,
         HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide,
@@ -663,9 +661,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
 
   // A / (B / C) => (A*C) / B
   if (rhs->opcode() == HloOpcode::kDivide) {
-    TF_ASSIGN_OR_RETURN(
-        auto a_times_c,
-        CreateBinaryHlo(HloOpcode::kMultiply, lhs, rhs->mutable_operand(1)));
+    TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, lhs,
+                                                      rhs->mutable_operand(1)));
     return ReplaceWithNewInstruction(
         divide,
         HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide,
@@ -1300,8 +1297,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
     }
 
     TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad,
-                        CreatePadHlo(pad->mutable_operand(0),
-                                     pad->mutable_operand(1), nonzero_padding));
+                        MakePadHlo(pad->mutable_operand(0),
+                                   pad->mutable_operand(1), nonzero_padding));
     // Copy the layout from the original pad instructions. The new pad and the
     // slice instruction should all have the same layout.
     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
@@ -1329,7 +1326,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
 
     TF_ASSIGN_OR_RETURN(
         HloInstruction * slice,
-        CreateSliceHlo(nonzero_pad, start_indices, end_indices, strides));
+        MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
 
     // Verify that the slice shape matches the pad shape.
     TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape()));
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index a133d81..58c62d8 100644
@@ -39,7 +39,7 @@ static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
     }
   }
   permutation.push_back(index_vector_dim);
-  return CreateTransposeHlo(gather_indices, permutation);
+  return MakeTransposeHlo(gather_indices, permutation);
 }
 
 // If the gather_indices holds scalar indices (i.e. gather_indices has rank N
@@ -133,16 +133,16 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
         dim_numbers.gather_dims_to_operand_dims_size()) {
       TF_ASSIGN_OR_RETURN(
           HloInstruction * component_to_concat,
-          CreateSliceHlo(
-              index_vector, /*start_indices=*/{index_vector_dim_index},
-              /*limit_indices=*/{index_vector_dim_index + 1}, /*strides=*/{1}));
+          MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
+                       /*limit_indices=*/{index_vector_dim_index + 1},
+                       /*strides=*/{1}));
       expanded_index_components.push_back(component_to_concat);
     } else {
       expanded_index_components.push_back(zero);
     }
   }
 
-  return CreateConcatHlo(expanded_index_components, /*dimension=*/0);
+  return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
 }
 
 // This generates the body of the while that implements the main data movement
@@ -159,8 +159,8 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * induction_var_as_vector,
-      CreateBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
-                         /*result_shape_bounds=*/{1}));
+      MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
+                       /*result_shape_bounds=*/{1}));
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * index_into_gather_indices,
@@ -169,8 +169,8 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * index_vector_2d,
-      CreateDynamicSliceHlo(gather_indices, index_into_gather_indices,
-                            {1, index_vector_size}));
+      MakeDynamicSliceHlo(gather_indices, index_into_gather_indices,
+                          {1, index_vector_size}));
 
   TF_ASSIGN_OR_RETURN(HloInstruction * index_vector,
                       ElideDegenerateDims(index_vector_2d, {0}));
@@ -181,8 +181,8 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
                           operand->shape().dimensions_size()));
 
   TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice,
-                      CreateDynamicSliceHlo(operand, gathered_slice_start,
-                                            gather.gather_window_bounds()));
+                      MakeDynamicSliceHlo(operand, gathered_slice_start,
+                                          gather.gather_window_bounds()));
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * gathered_slice_for_update,
@@ -197,8 +197,8 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * updated_accumulator,
-      CreateDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update,
-                                  index_vector_into_accumulator));
+      MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update,
+                                index_vector_into_accumulator));
 
   // New loop state -- only the accumulator has changed.  The
   // WhileUtil::MakeCountedLoop functions takes care of the induction variable
@@ -250,7 +250,7 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
     }
   }
 
-  return CreateTransposeHlo(accumulator, permutation);
+  return MakeTransposeHlo(accumulator, permutation);
 }
 
 // High Level Algorithm
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index fa405b9..7bda4e2 100644
@@ -69,7 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput(
     HloInstruction* padding =
         computation->AddInstruction(HloInstruction::CreateConstant(
             MakeUnique<Literal>(Literal::Zero(element_type))));
-    input = CreatePadHlo(input, padding, padding_config).ValueOrDie();
+    input = MakePadHlo(input, padding, padding_config).ValueOrDie();
   }
 
   if (window_util::HasNegativePadding(conv_window)) {
@@ -92,8 +92,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
           std::max<int64>(0LL, -conv_window.dimensions(i).padding_high());
     }
 
-    input = CreateSliceHlo(input, start_indices, limit_indices, strides)
-                .ValueOrDie();
+    input =
+        MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie();
   }
 
   return input;
@@ -126,7 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
   HloInstruction* padding =
       computation->AddInstruction(HloInstruction::CreateConstant(
           MakeUnique<Literal>(Literal::Zero(element_type))));
-  return CreatePadHlo(kernel, padding, padding_config).ValueOrDie();
+  return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
 }
 }  // namespace
 
@@ -238,7 +238,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
       computation->AddInstruction(HloInstruction::CreateConstant(
           MakeUnique<Literal>(Literal::Zero(input->shape().element_type()))));
   HloInstruction* padded_input =
-      CreatePadHlo(input, padding, input_padding_config).ValueOrDie();
+      MakePadHlo(input, padding, input_padding_config).ValueOrDie();
 
   // The shape of the backward_conv CustomCall is a tuple (conv_result,
   // scratch_buffer).  Extract out the shape of conv_result.
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 4585bff..fbe71f8 100644
@@ -23,8 +23,8 @@ namespace xla {
 using tensorflow::gtl::ArraySlice;
 using tensorflow::strings::StrCat;
 
-StatusOr<HloInstruction*> CreateBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
-                                          HloInstruction* rhs) {
+StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
+                                        HloInstruction* rhs) {
   HloComputation* computation = lhs->parent();
   CHECK_EQ(computation, rhs->parent());
   TF_ASSIGN_OR_RETURN(Shape binary_op_shape,
@@ -33,9 +33,9 @@ StatusOr<HloInstruction*> CreateBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
       HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
 }
 
-StatusOr<HloInstruction*> CreatePadHlo(HloInstruction* operand,
-                                       HloInstruction* padding_value,
-                                       const PaddingConfig& padding_config) {
+StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
+                                     HloInstruction* padding_value,
+                                     const PaddingConfig& padding_config) {
   HloComputation* computation = operand->parent();
   CHECK_EQ(computation, padding_value->parent());
   TF_ASSIGN_OR_RETURN(
@@ -46,10 +46,10 @@ StatusOr<HloInstruction*> CreatePadHlo(HloInstruction* operand,
       pad_shape, operand, padding_value, padding_config));
 }
 
-StatusOr<HloInstruction*> CreateSliceHlo(HloInstruction* operand,
-                                         ArraySlice<int64> start_indices,
-                                         ArraySlice<int64> limit_indices,
-                                         ArraySlice<int64> strides) {
+StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
+                                       ArraySlice<int64> start_indices,
+                                       ArraySlice<int64> limit_indices,
+                                       ArraySlice<int64> strides) {
   HloComputation* computation = operand->parent();
   TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
                                              operand->shape(), start_indices,
@@ -58,7 +58,7 @@ StatusOr<HloInstruction*> CreateSliceHlo(HloInstruction* operand,
       slice_shape, operand, start_indices, limit_indices, strides));
 }
 
-StatusOr<HloInstruction*> CreateConvolveHlo(
+StatusOr<HloInstruction*> MakeConvolveHlo(
     HloInstruction* lhs, HloInstruction* rhs, const Window& window,
     const ConvolutionDimensionNumbers& dimension_numbers) {
   HloComputation* computation = lhs->parent();
@@ -70,8 +70,8 @@ StatusOr<HloInstruction*> CreateConvolveHlo(
       convolve_shape, lhs, rhs, window, dimension_numbers));
 }
 
-StatusOr<HloInstruction*> CreateTransposeHlo(HloInstruction* operand,
-                                             ArraySlice<int64> dimensions) {
+StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
+                                           ArraySlice<int64> dimensions) {
   HloComputation* computation = operand->parent();
   TF_ASSIGN_OR_RETURN(
       Shape transpose_shape,
@@ -80,23 +80,23 @@ StatusOr<HloInstruction*> CreateTransposeHlo(HloInstruction* operand,
       HloInstruction::CreateTranspose(transpose_shape, operand, dimensions));
 }
 
-StatusOr<HloInstruction*> CreateReshapeHlo(const Shape& result_shape,
-                                           HloInstruction* operand) {
+StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
+                                         HloInstruction* operand) {
   HloComputation* computation = operand->parent();
   return computation->AddInstruction(
       HloInstruction::CreateReshape(result_shape, operand));
 }
 
-StatusOr<HloInstruction*> CreateReshapeHlo(
+StatusOr<HloInstruction*> MakeReshapeHlo(
     ArraySlice<int64> result_shape_dim_bounds, HloInstruction* operand) {
   Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                          result_shape_dim_bounds);
-  return CreateReshapeHlo(new_shape, operand);
+  return MakeReshapeHlo(new_shape, operand);
 }
 
-StatusOr<HloInstruction*> CreateDynamicSliceHlo(HloInstruction* operand,
-                                                HloInstruction* start_indices,
-                                                ArraySlice<int64> slice_sizes) {
+StatusOr<HloInstruction*> MakeDynamicSliceHlo(HloInstruction* operand,
+                                              HloInstruction* start_indices,
+                                              ArraySlice<int64> slice_sizes) {
   HloComputation* computation = operand->parent();
   CHECK_EQ(computation, start_indices->parent());
   TF_ASSIGN_OR_RETURN(
@@ -107,7 +107,7 @@ StatusOr<HloInstruction*> CreateDynamicSliceHlo(HloInstruction* operand,
       dynamic_slice_shape, operand, start_indices, slice_sizes));
 }
 
-StatusOr<HloInstruction*> CreateDynamicUpdateSliceHlo(
+StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
     HloInstruction* operand, HloInstruction* update,
     HloInstruction* start_indices) {
   HloComputation* computation = operand->parent();
@@ -121,7 +121,7 @@ StatusOr<HloInstruction*> CreateDynamicUpdateSliceHlo(
       dynamic_update_slice_shape, operand, update, start_indices));
 }
 
-StatusOr<HloInstruction*> CreateBroadcastHlo(
+StatusOr<HloInstruction*> MakeBroadcastHlo(
     HloInstruction* operand, ArraySlice<int64> broadcast_dimensions,
     ArraySlice<int64> result_shape_bounds) {
   HloComputation* computation = operand->parent();
@@ -132,8 +132,8 @@ StatusOr<HloInstruction*> CreateBroadcastHlo(
       broadcast_shape, operand, broadcast_dimensions));
 }
 
-StatusOr<HloInstruction*> CreateGetTupleElementHlo(HloInstruction* operand,
-                                                   int64 index) {
+StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
+                                                 int64 index) {
   HloComputation* computation = operand->parent();
 
   TF_ASSIGN_OR_RETURN(
@@ -143,8 +143,8 @@ StatusOr<HloInstruction*> CreateGetTupleElementHlo(HloInstruction* operand,
       HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
 }
 
-StatusOr<HloInstruction*> CreateConcatHlo(ArraySlice<HloInstruction*> operands,
-                                          int64 dimension) {
+StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands,
+                                        int64 dimension) {
   CHECK_GT(operands.size(), 0);
 
   HloComputation* computation = operands[0]->parent();
@@ -181,7 +181,7 @@ StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
   Shape output_shape =
       ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);
 
-  return CreateReshapeHlo(output_shape, operand);
+  return MakeReshapeHlo(output_shape, operand);
 }
 
 StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
@@ -198,7 +198,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
             std::back_inserter(expanded_shape_dim_bounds));
   Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                          expanded_shape_dim_bounds);
-  return CreateReshapeHlo(new_shape, operand);
+  return MakeReshapeHlo(new_shape, operand);
 }
 
 StatusOr<HloInstruction*> ExpandLastDimIntoNDims(
@@ -216,7 +216,7 @@ StatusOr<HloInstruction*> ExpandLastDimIntoNDims(
   c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
   Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                          expanded_shape_dim_bounds);
-  return CreateReshapeHlo(new_shape, operand);
+  return MakeReshapeHlo(new_shape, operand);
 }
 
 StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
@@ -241,7 +241,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
   c_reverse(new_shape_dim_bounds);
   Shape output_shape =
       ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds);
-  return CreateReshapeHlo(output_shape, operand);
+  return MakeReshapeHlo(output_shape, operand);
 }
 
 StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
@@ -258,7 +258,7 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
   HloInstruction* zero =
       computation->AddInstruction(HloInstruction::CreateConstant(
           MakeUnique<Literal>(Literal::Zero(operand->shape().element_type()))));
-  return CreatePadHlo(operand, zero, padding_config);
+  return MakePadHlo(operand, zero, padding_config);
 }
 
 StatusOr<HloInstruction*> BroadcastZeros(
@@ -267,8 +267,8 @@ StatusOr<HloInstruction*> BroadcastZeros(
   HloInstruction* zero =
       computation->AddInstruction(HloInstruction::CreateConstant(
           MakeUnique<Literal>(Literal::Zero(element_type))));
-  return CreateBroadcastHlo(zero, /*broadcast_dimensions=*/{},
-                            /*result_shape_bounds=*/broadcast_dimensions);
+  return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
+                          /*result_shape_bounds=*/broadcast_dimensions);
 }
 
 StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 2b03a84..6032eba 100644
@@ -28,73 +28,73 @@ namespace xla {
 
 // Creates a binary HLO instruction and adds it to the computation containing
 // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
-StatusOr<HloInstruction*> CreateBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
-                                          HloInstruction* rhs);
+StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
+                                        HloInstruction* rhs);
 
 // Creates a pad HLO instruction and adds it to the computation containing
 // `operand` and `padding_value` (`operand` and `padding_value` must be in the
 // same computation).
-StatusOr<HloInstruction*> CreatePadHlo(HloInstruction* operand,
-                                       HloInstruction* padding_value,
-                                       const PaddingConfig& padding_config);
+StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
+                                     HloInstruction* padding_value,
+                                     const PaddingConfig& padding_config);
 
 // Creates a slice HLO instruction and adds it to the computation containing
 // `operand`.
-StatusOr<HloInstruction*> CreateSliceHlo(
+StatusOr<HloInstruction*> MakeSliceHlo(
     HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> start_indices,
     tensorflow::gtl::ArraySlice<int64> limit_indices,
     tensorflow::gtl::ArraySlice<int64> strides);
 
 // Creates a convolution HLO instruction and adds it to the computation
 // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
-StatusOr<HloInstruction*> CreateConvolveHlo(
+StatusOr<HloInstruction*> MakeConvolveHlo(
     HloInstruction* lhs, HloInstruction* rhs, const Window& window,
     const ConvolutionDimensionNumbers& dimension_numbers);
 
 // Creates a transpose HLO instruction and adds it to the computation containing
 // `operand`.
-StatusOr<HloInstruction*> CreateTransposeHlo(
+StatusOr<HloInstruction*> MakeTransposeHlo(
     HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions);
 
 // Creates a reshape HLO instruction and adds it to the computation containing
 // `operand`.
-StatusOr<HloInstruction*> CreateReshapeHlo(const Shape& result_shape,
-                                           HloInstruction* operand);
+StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
+                                         HloInstruction* operand);
 
-StatusOr<HloInstruction*> CreateReshapeHlo(
+StatusOr<HloInstruction*> MakeReshapeHlo(
     tensorflow::gtl::ArraySlice<int64> result_shape_dim_bounds,
     HloInstruction* operand);
 
 // Creates a dynamic-slice HLO instruction and adds it to the computation
 // containing `operand` and `start_indices` (`operand` and `start_indices` must
 // be in the same computation).
-StatusOr<HloInstruction*> CreateDynamicSliceHlo(
+StatusOr<HloInstruction*> MakeDynamicSliceHlo(
     HloInstruction* operand, HloInstruction* start_indices,
     tensorflow::gtl::ArraySlice<int64> slice_sizes);
 
 // Creates a dynamic-update-slice HLO instruction and adds it to the computation
 // containing `operand`, `update` and `start_indices` (`operand`, `update` and
 // `start_indices` must be in the same computation).
-StatusOr<HloInstruction*> CreateDynamicUpdateSliceHlo(
+StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
     HloInstruction* operand, HloInstruction* update,
     HloInstruction* start_indices);
 
 // Creates a broadcast HLO instruction and adds it to the computation containing
 // `operand`.
-StatusOr<HloInstruction*> CreateBroadcastHlo(
+StatusOr<HloInstruction*> MakeBroadcastHlo(
     HloInstruction* operand,
     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions,
     tensorflow::gtl::ArraySlice<int64> result_shape_bounds);
 
 // Creates a GetTupleElement HLO instruction and adds it to the computation
 // containing `operand`.
-StatusOr<HloInstruction*> CreateGetTupleElementHlo(HloInstruction* operand,
-                                                   int64 index);
+StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
+                                                 int64 index);
 
 // Creates a Concatenate HLO instruction and adds it to the computation
 // containing `operands` (`operands` must be non-empty and every element must be
 // contained in the same computation).
-StatusOr<HloInstruction*> CreateConcatHlo(
+StatusOr<HloInstruction*> MakeConcatHlo(
     tensorflow::gtl::ArraySlice<HloInstruction*> operands, int64 dimension);
 
 // -----------------------------------------------------------------------------
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index 7441a7a..8cd5882 100644
@@ -155,10 +155,10 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
 
   HloInstruction* param = cond_computation->parameter_instruction(0);
   TF_ASSIGN_OR_RETURN(HloInstruction * counter,
-                      CreateGetTupleElementHlo(param, 0));
+                      MakeGetTupleElementHlo(param, 0));
   TF_ASSIGN_OR_RETURN(
       HloInstruction * compare,
-      CreateBinaryHlo(HloOpcode::kLt, counter, trip_count_constant));
+      MakeBinaryHlo(HloOpcode::kLt, counter, trip_count_constant));
   cond_computation->set_root_instruction(compare);
   return std::move(cond_computation);
 }
@@ -175,14 +175,14 @@ static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
 
   HloInstruction* param = body_computation->parameter_instruction(0);
   TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
-                      CreateGetTupleElementHlo(param, 0));
+                      MakeGetTupleElementHlo(param, 0));
   TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar,
-                      CreateBinaryHlo(HloOpcode::kAdd, indvar, one));
+                      MakeBinaryHlo(HloOpcode::kAdd, indvar, one));
 
   std::vector<HloInstruction*> loop_body_generator_args;
   for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) {
     TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element,
-                        CreateGetTupleElementHlo(param, i));
+                        MakeGetTupleElementHlo(param, i));
     loop_body_generator_args.push_back(tuple_element);
   }
   TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> next_state,
@@ -238,7 +238,7 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) {
   std::vector<HloInstruction*> result;
   for (int64 i = 0, e = init_values.size(); i < e; i++) {
     TF_ASSIGN_OR_RETURN(HloInstruction * user_state,
-                        CreateGetTupleElementHlo(while_instr, i + 1));
+                        MakeGetTupleElementHlo(while_instr, i + 1));
     result.push_back(user_state);
   }
   return result;