change GetFieldValueFromSqlBufferInternal<T> to check for XmlReader
authorWraith2 <wraith2@gmail.com>
Sat, 26 Jan 2019 22:12:30 +0000 (22:12 +0000)
committerWraith2 <wraith2@gmail.com>
Sat, 26 Jan 2019 22:12:30 +0000 (22:12 +0000)
add SqlTypeWorkarounds xml reader creation overload for TextReader
add DataStream test for GetFielvdValue<XmlReader> for value and null

Commit migrated from https://github.com/dotnet/corefx/commit/b4792140f978c102e9c104db6972d41646d12c98

src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SqlDataReader.cs
src/libraries/System.Data.SqlClient/src/System/Data/SqlTypes/SqlTypeWorkarounds.cs
src/libraries/System.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs

index e66d5cd..a509e71 100644 (file)
@@ -2588,10 +2588,10 @@ namespace System.Data.SqlClient
                 // If its a SQL Type or Nullable UDT
                 object rawValue = GetSqlValueFromSqlBufferInternal(data, metaData);
 
-                // Special case: User wants SqlString, but we have a SqlXml
-                // SqlXml can not be typecast into a SqlString, but we need to support SqlString on XML Types - so do a manual conversion
                 if (typeofT == s_typeofSqlString)
                 {
+                    // Special case: User wants SqlString, but we have a SqlXml
+                    // SqlXml can not be typecast into a SqlString, but we need to support SqlString on XML Types - so do a manual conversion
                     SqlXml xmlValue = rawValue as SqlXml;
                     if (xmlValue != null)
                     {
@@ -2610,22 +2610,58 @@ namespace System.Data.SqlClient
             }
             else
             {
-                // Otherwise Its a CLR or non-Nullable UDT
-                try
+                if (typeof(XmlReader) == typeofT)
                 {
-                    return (T)GetValueFromSqlBufferInternal(data, metaData);
+                    if (metaData.metaType.SqlDbType != SqlDbType.Xml)
+                    {
+                        throw SQL.XmlReaderNotSupportOnColumnType(metaData.column);
+                    }
+                    else
+                    {
+                        object clrValue = null;
+                        if (!data.IsNull)
+                        {
+                            clrValue = GetValueFromSqlBufferInternal(data, metaData);
+                        }
+                        if (clrValue is null) // covers IsNull and when there is data which is present but is a clr null somehow
+                        {
+                            return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(
+                                new MemoryStream(Array.Empty<byte>(), writable: false),
+                                closeInput: true
+                            );
+                        }
+                        else if (clrValue.GetType() == typeof(string))
+                        {
+                            return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(
+                                new StringReader(clrValue as string), 
+                                closeInput: true
+                            );
+                        }
+                        else
+                        {
+                            // try the type cast to throw the invalid cast exception and inform the user what types they're trying to use and that why it is wrong
+                            return (T)clrValue;
+                        }
+                    }
                 }
-                catch (InvalidCastException)
+                else
                 {
-                    if (data.IsNull)
+                    try
                     {
-                        // If the value was actually null, then we should throw a SqlNullValue instead
-                        throw SQL.SqlNullValue();
+                        return (T)GetValueFromSqlBufferInternal(data, metaData);
                     }
-                    else
+                    catch (InvalidCastException)
                     {
-                        // Legitimate InvalidCast, rethrow
-                        throw;
+                        if (data.IsNull)
+                        {
+                            // If the value was actually null, then we should throw a SqlNullValue instead
+                            throw SQL.SqlNullValue();
+                        }
+                        else
+                        {
+                            // Legitimate InvalidCast, rethrow
+                            throw;
+                        }
                     }
                 }
             }
index 1757441..0e18a0d 100644 (file)
@@ -37,6 +37,17 @@ namespace System.Data.SqlTypes
 
             return XmlReader.Create(stream, settingsToUse);
         }
+
+        internal static XmlReader SqlXmlCreateSqlXmlReader(TextReader textReader, bool closeInput = false, bool async = false)
+        {
+            Debug.Assert(closeInput || !async, "Currently we do not have pre-created settings for !closeInput+async");
+
+            XmlReaderSettings settingsToUse = closeInput ?
+                (async ? s_defaultXmlReaderSettingsAsyncCloseInput : s_defaultXmlReaderSettingsCloseInput) :
+                s_defaultXmlReaderSettings;
+
+            return XmlReader.Create(textReader, settingsToUse);
+        }
         #endregion
 
         #region Work around inability to access SqlDateTime.ToDateTime
index add701b..4354954 100644 (file)
@@ -294,6 +294,8 @@ namespace System.Data.SqlClient.ManualTesting.Tests
                     rdr.GetFieldValue<SqlXml>(15);
                     rdr.GetFieldValue<SqlString>(14);
                     rdr.GetFieldValue<SqlString>(15);
+                    rdr.GetFieldValue<XmlReader>(14);
+                    rdr.GetFieldValue<XmlReader>(15);
 
                     rdr.Read();
                     Assert.True(rdr.IsDBNullAsync(11).Result, "FAILED: IsDBNull was false for a null value");