Support op_Exponent and op_Exponentiation in S.L.Expressions (dotnet/corefx#26413)
authorJon Hanna <jon@hackcraft.net>
Mon, 22 Jan 2018 20:33:35 +0000 (20:33 +0000)
committerVladimir Sadov <vsadov@microsoft.com>
Mon, 22 Jan 2018 20:33:35 +0000 (12:33 -0800)
Provides better support within Expression.Power for VB and F# which
produce these operators.

Fixes dotnet/corefx#26411

Commit migrated from https://github.com/dotnet/corefx/commit/3102a879b949843e91563107f602eb9ff61c2afb

src/libraries/System.Linq.Expressions/src/System/Linq/Expressions/BinaryExpression.cs
src/libraries/System.Linq.Expressions/tests/BinaryOperators/Arithmetic/BinaryMultiplyTests.cs

index 60d45b9..87a7ed9 100644 (file)
@@ -2787,12 +2787,34 @@ namespace System.Linq.Expressions
             ExpressionUtils.RequiresCanRead(right, nameof(right));
             if (method == null)
             {
-                method = Math_Pow_Double_Double;
-                if (method == null)
+                if (left.Type == right.Type && left.Type.IsArithmetic())
+                {
+                    method = Math_Pow_Double_Double;
+                    Debug.Assert(method != null);
+                }
+                else
                 {
-                    throw Error.BinaryOperatorNotDefined(ExpressionType.Power, left.Type, right.Type);
+                    // VB uses op_Exponent, F# uses op_Exponentiation. This inconsistency is unfortunate, but we can
+                    // test for either.
+                    string name = "op_Exponent";
+                    BinaryExpression b = GetUserDefinedBinaryOperator(ExpressionType.Power, name, left, right, liftToNull: true);
+                    if (b == null)
+                    {
+                        name = "op_Exponentiation";
+                        b = GetUserDefinedBinaryOperator(ExpressionType.Power, name, left, right, liftToNull: true);
+                        if (b == null)
+                        {
+                            throw Error.BinaryOperatorNotDefined(ExpressionType.Power, left.Type, right.Type);
+                        }
+                    }
+
+                    ParameterInfo[] pis = b.Method.GetParametersCached();
+                    ValidateParamswithOperandsOrThrow(pis[0].ParameterType, left.Type, ExpressionType.Power, name);
+                    ValidateParamswithOperandsOrThrow(pis[1].ParameterType, right.Type, ExpressionType.Power, name);
+                    return b;
                 }
             }
+
             return GetMethodBasedBinaryOperator(ExpressionType.Power, left, right, method, liftToNull: true);
         }
 
index 7134ea7..3914674 100644 (file)
@@ -576,5 +576,87 @@ namespace System.Linq.Expressions.Tests
             BinaryExpression e2 = Expression.MultiplyChecked(Expression.Parameter(typeof(int), "a"), Expression.Parameter(typeof(int), "b"));
             Assert.Equal("(a * b)", e2.ToString());
         }
+
+        // Simulate VB-style overloading of exponentiation operation
+        public struct VBStyleExponentiation
+        {
+            public VBStyleExponentiation(double value) => Value = value;
+
+            public double Value { get; }
+
+            public static implicit operator VBStyleExponentiation(double value) => new VBStyleExponentiation(value);
+
+            public static VBStyleExponentiation op_Exponent(VBStyleExponentiation x, VBStyleExponentiation y) => Math.Pow(x.Value, y.Value);
+        }
+
+        [Theory, ClassData(typeof(CompilationTypes))]
+        public static void VBStyleOperatorOverloading(bool useInterpreter)
+        {
+            var b = Expression.Parameter(typeof(VBStyleExponentiation));
+            var e = Expression.Parameter(typeof(VBStyleExponentiation));
+            var func = Expression.Lambda<Func<VBStyleExponentiation, VBStyleExponentiation, VBStyleExponentiation>>(
+                    Expression.Power(b, e), b, e).Compile(useInterpreter);
+            Assert.Equal(8.0, func(2.0, 3.0).Value);
+            Assert.Equal(10000.0, func(10.0, 4.0).Value);
+        }
+
+        [Theory, ClassData(typeof(CompilationTypes))]
+        public static void VBStyleOperatorOverloadingLifted(bool useInterpreter)
+        {
+            var b = Expression.Parameter(typeof(VBStyleExponentiation?));
+            var e = Expression.Parameter(typeof(VBStyleExponentiation?));
+            var func = Expression.Lambda<Func<VBStyleExponentiation?, VBStyleExponentiation?, VBStyleExponentiation?>>(
+                Expression.Power(b, e), b, e).Compile(useInterpreter);
+            Assert.Equal(8.0, func(2.0, 3.0).Value.Value);
+            Assert.Equal(10000.0, func(10.0, 4.0).Value.Value);
+            Assert.Null(func(2.0, null));
+            Assert.Null(func(null, 2.0));
+            Assert.Null(func(null, null));
+        }
+
+        // Simulate F#-style overloading of exponentiation operation
+        public struct FSStyleExponentiation
+        {
+            public FSStyleExponentiation(double value) => Value = value;
+
+            public static implicit operator FSStyleExponentiation(double value) => new FSStyleExponentiation(value);
+
+            public double Value { get; }
+
+            public static FSStyleExponentiation op_Exponentiation(FSStyleExponentiation x, FSStyleExponentiation y)
+                => new FSStyleExponentiation(Math.Pow(x.Value, y.Value));
+        }
+
+        [Theory, ClassData(typeof(CompilationTypes))]
+        public static void FSStyleOperatorOverloading(bool useInterpreter)
+        {
+            var b = Expression.Parameter(typeof(FSStyleExponentiation));
+            var e = Expression.Parameter(typeof(FSStyleExponentiation));
+            var func = Expression.Lambda<Func<FSStyleExponentiation, FSStyleExponentiation, FSStyleExponentiation>>(
+                Expression.Power(b, e), b, e).Compile(useInterpreter);
+            Assert.Equal(8.0, func(2.0, 3.0).Value);
+            Assert.Equal(10000.0, func(10.0, 4.0).Value);
+        }
+
+        [Theory, ClassData(typeof(CompilationTypes))]
+        public static void FSStyleOperatorOverloadingLifted(bool useInterpreter)
+        {
+            var b = Expression.Parameter(typeof(FSStyleExponentiation?));
+            var e = Expression.Parameter(typeof(FSStyleExponentiation?));
+            var func = Expression.Lambda<Func<FSStyleExponentiation?, FSStyleExponentiation?, FSStyleExponentiation?>>(
+                Expression.Power(b, e), b, e).Compile(useInterpreter);
+            Assert.Equal(8.0, func(2.0, 3.0).Value.Value);
+            Assert.Equal(10000.0, func(10.0, 4.0).Value.Value);
+            Assert.Null(func(2.0, null));
+            Assert.Null(func(null, 2.0));
+            Assert.Null(func(null, null));
+        }
+
+        [Fact]
+        public static void ExponentiationNotSupported()
+        {
+            ConstantExpression arg = Expression.Constant("");
+            Assert.Throws<InvalidOperationException>(() => Expression.Power(arg, arg));
+        }
     }
 }