Pass HloOpcode instead of HloInstruction; NFC
authorSanjoy Das <sanjoy@google.com>
Tue, 29 May 2018 06:03:39 +0000 (23:03 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 06:06:35 +0000 (23:06 -0700)
Minor code cleanup change.

PiperOrigin-RevId: 198351045

tensorflow/compiler/xla/service/indexed_array_analysis.cc
tensorflow/compiler/xla/service/indexed_array_analysis.h

index 21af9a6..11d931c 100644 (file)
@@ -127,14 +127,16 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
     const HloInstruction* instr) {
   Array* computed_array;
   if (instr->IsElementwise() && instr->operand_count() == 1) {
-    TF_ASSIGN_OR_RETURN(computed_array,
-                        ComputeArrayForElementwiseUnaryOp(
-                            instr, FindOrDie(cache_, instr->operand(0))));
+    TF_ASSIGN_OR_RETURN(
+        computed_array,
+        ComputeArrayForElementwiseUnaryOp(
+            instr->opcode(), FindOrDie(cache_, instr->operand(0))));
   } else if (instr->IsElementwise() && instr->operand_count() == 2) {
-    TF_ASSIGN_OR_RETURN(computed_array,
-                        ComputeArrayForElementwiseBinaryOp(
-                            instr, FindOrDie(cache_, instr->operand(0)),
-                            FindOrDie(cache_, instr->operand(1))));
+    TF_ASSIGN_OR_RETURN(
+        computed_array,
+        ComputeArrayForElementwiseBinaryOp(
+            instr->opcode(), FindOrDie(cache_, instr->operand(0)),
+            FindOrDie(cache_, instr->operand(1))));
   } else if (instr->opcode() == HloOpcode::kConstant) {
     TF_ASSIGN_OR_RETURN(computed_array,
                         ComputeArrayForConstant(instr->literal()));
@@ -563,8 +565,9 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
 }
 
 StatusOr<Analysis::Array*>
-IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
-    const HloInstruction* instr, Array* lhs, Array* rhs) {
+IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
+                                                         Array* lhs,
+                                                         Array* rhs) {
   // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
   //          => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
   //
@@ -664,14 +667,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
     TF_ASSIGN_OR_RETURN(
         literal_for_new_source,
         TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
-            instr->opcode(), scalar_indexed_const->literal(),
-            *inner_broadcast_result)));
+            opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
   } else {
     TF_ASSIGN_OR_RETURN(
         literal_for_new_source,
         TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
-            instr->opcode(), *inner_broadcast_result,
-            scalar_indexed_const->literal())));
+            opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
   }
 
   ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
@@ -684,8 +685,8 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
 }
 
 StatusOr<Analysis::Array*>
-IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
-    const HloInstruction* instr, Array* operand) {
+IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
+                                                        Array* operand) {
   auto* scalar_indexed_const =
       dynamic_cast<ScalarIndexedConstantArray*>(operand);
   if (operand == nullptr) {
@@ -697,7 +698,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
 
   TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
                       TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp(
-                          instr->opcode(), scalar_indexed_const->literal())));
+                          opcode, scalar_indexed_const->literal())));
   ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
   return Construct<ScalarIndexedConstantArray>(
       new_source, scalar_indexed_const->indices(),
index 561832a..ce92fd2 100644 (file)
@@ -268,10 +268,10 @@ class IndexedArrayAnalysis {
 
   StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);
 
-  StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(
-      const HloInstruction* instr, Array* lhs, Array* rhs);
-  StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(
-      const HloInstruction* instr, Array* operand);
+  StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
+                                                      Array* lhs, Array* rhs);
+  StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
+                                                     Array* operand);
 
   template <typename T, typename... Args>
   T* Construct(Args&&... args) {