Add async System.Data resultset and database schema APIs (#39098)
authorShay Rojansky <roji@roji.org>
Sun, 12 Jul 2020 13:56:20 +0000 (16:56 +0300)
committerGitHub <noreply@github.com>
Sun, 12 Jul 2020 13:56:20 +0000 (15:56 +0200)
Closes #38028

src/libraries/System.Data.Common/ref/System.Data.Common.cs
src/libraries/System.Data.Common/src/System/Data/Common/DbConnection.cs
src/libraries/System.Data.Common/src/System/Data/Common/DbDataReader.cs
src/libraries/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs
src/libraries/System.Data.Common/tests/System/Data/Common/DbDataReaderMock.cs
src/libraries/System.Data.Common/tests/System/Data/Common/DbDataReaderTest.cs

index eb58830..c5a8f8d 100644 (file)
@@ -1933,6 +1933,9 @@ namespace System.Data.Common
         public virtual System.Data.DataTable GetSchema() { throw null; }
         public virtual System.Data.DataTable GetSchema(string collectionName) { throw null; }
         public virtual System.Data.DataTable GetSchema(string collectionName, string?[] restrictionValues) { throw null; }
+        public virtual System.Threading.Tasks.Task<System.Data.DataTable> GetSchemaAsync(System.Threading.CancellationToken cancellationToken = default) { throw null; }
+        public virtual System.Threading.Tasks.Task<System.Data.DataTable> GetSchemaAsync(string collectionName, System.Threading.CancellationToken cancellationToken = default) { throw null; }
+        public virtual System.Threading.Tasks.Task<System.Data.DataTable> GetSchemaAsync(string collectionName, string?[] restrictionValues, System.Threading.CancellationToken cancellationToken = default) { throw null; }
         protected virtual void OnStateChange(System.Data.StateChangeEventArgs stateChange) { }
         public abstract void Open();
         public System.Threading.Tasks.Task OpenAsync() { throw null; }
@@ -2111,6 +2114,8 @@ namespace System.Data.Common
         [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
         public virtual int GetProviderSpecificValues(object[] values) { throw null; }
         public virtual System.Data.DataTable GetSchemaTable() { throw null; }
+        public virtual System.Threading.Tasks.Task<System.Data.DataTable> GetSchemaTableAsync(System.Threading.CancellationToken cancellationToken = default) { throw null; }
+        public virtual System.Threading.Tasks.Task<System.Collections.ObjectModel.ReadOnlyCollection<System.Data.Common.DbColumn>> GetColumnSchemaAsync(System.Threading.CancellationToken cancellationToken = default) { throw null; }
         public virtual System.IO.Stream GetStream(int ordinal) { throw null; }
         public abstract string GetString(int ordinal);
         public virtual System.IO.TextReader GetTextReader(int ordinal) { throw null; }
index 130611b..7b99a8d 100644 (file)
@@ -143,21 +143,166 @@ namespace System.Data.Common
 
         // these need to be here so that GetSchema is visible when programming to a dbConnection object.
         // they are overridden by the real implementations in DbConnectionBase
+
+        /// <summary>
+        /// Returns schema information for the data source of this <see cref="DbConnection" />.
+        /// </summary>
+        /// <returns>A <see cref="DataTable" /> that contains schema information.</returns>
+        /// <remarks>
+        /// If the connection is associated with a transaction, executing <see cref="GetSchema()" /> calls may cause
+        /// some providers to throw an exception.
+        /// </remarks>
         public virtual DataTable GetSchema()
         {
             throw ADP.NotSupported();
         }
 
+        /// <summary>
+        /// Returns schema information for the data source of this <see cref="DbConnection" /> using the specified
+        /// string for the schema name.
+        /// </summary>
+        /// <param name="collectionName">Specifies the name of the schema to return.</param>
+        /// <returns>A <see cref="DataTable" /> that contains schema information.</returns>
+        /// <exception cref="ArgumentException">
+        /// <paramref name="collectionName" /> is specified as <see langword="null" />.
+        /// </exception>
+        /// <remarks>
+        /// If the connection is associated with a transaction, executing <see cref="GetSchema(string)" /> calls may cause
+        /// some providers to throw an exception.
+        /// </remarks>
         public virtual DataTable GetSchema(string collectionName)
         {
             throw ADP.NotSupported();
         }
 
+        /// <summary>
+        /// Returns schema information for the data source of this <see cref="DbConnection" /> using the specified
+        /// string for the schema name and the specified string array for the restriction values.
+        /// </summary>
+        /// <param name="collectionName">Specifies the name of the schema to return.</param>
+        /// <param name="restrictionValues">Specifies a set of restriction values for the requested schema.</param>
+        /// <returns>A <see cref="DataTable" /> that contains schema information.</returns>
+        /// <exception cref="ArgumentException">
+        /// <paramref name="collectionName" /> is specified as <see langword="null" />.
+        /// </exception>
+        /// <remarks>
+        /// <para>
+        /// The <paramref name="restrictionValues" /> parameter can supply n depth of values, which are specified by the
+        /// restrictions collection for a specific collection. In order to set values on a given restriction, and not
+        /// set the values of other restrictions, you need to set the preceding restrictions to null and then put the
+        /// appropriate value in for the restriction that you would like to specify a value for.
+        /// </para>
+        /// <para>
+        /// An example of this is the "Tables" collection. If the "Tables" collection has three restrictions (database,
+        /// owner, and table name) and you want to get back only the tables associated with the owner "Carl", you must
+        /// pass in the following values at least: null, "Carl". If a restriction value is not passed in, the default
+        /// values are used for that restriction. This is the same mapping as passing in null, which is different from
+        /// passing in an empty string for the parameter value. In that case, the empty string ("") is considered to be
+        /// the value for the specified parameter.
+        /// </para>
+        /// <para>
+        /// If the connection is associated with a transaction, executing <see cref="GetSchema(string, string[])" />
+        /// calls may cause some providers to throw an exception.
+        /// </para>
+        /// </remarks>
         public virtual DataTable GetSchema(string collectionName, string?[] restrictionValues)
         {
             throw ADP.NotSupported();
         }
 
+        /// <summary>
+        /// This is the asynchronous version of <see cref="GetSchema()" />.
+        /// Providers should override with an appropriate implementation.
+        /// The cancellation token can optionally be honored.
+        /// The default implementation invokes the synchronous <see cref="GetSchema()" /> call and returns a completed
+        /// task.
+        /// The default implementation will return a cancelled task if passed an already cancelled cancellationToken.
+        /// Exceptions thrown by <see cref="GetSchema()" /> will be communicated via the returned Task Exception
+        /// property.
+        /// </summary>
+        /// <param name="cancellationToken">The cancellation instruction.</param>
+        /// <returns>A task representing the asynchronous operation.</returns>
+        public virtual Task<DataTable> GetSchemaAsync(CancellationToken cancellationToken = default)
+        {
+            if (cancellationToken.IsCancellationRequested)
+            {
+                return Task.FromCanceled<DataTable>(cancellationToken);
+            }
+
+            try
+            {
+                return Task.FromResult(GetSchema());
+            }
+            catch (Exception e)
+            {
+                return Task.FromException<DataTable>(e);
+            }
+        }
+
+        /// <summary>
+        /// This is the asynchronous version of <see cref="GetSchema(string)" />.
+        /// Providers should override with an appropriate implementation.
+        /// The cancellation token can optionally be honored.
+        /// The default implementation invokes the synchronous <see cref="GetSchema(string)" /> call and returns a
+        /// completed task.
+        /// The default implementation will return a cancelled task if passed an already cancelled cancellationToken.
+        /// Exceptions thrown by <see cref="GetSchema(string)" /> will be communicated via the returned Task Exception
+        /// property.
+        /// </summary>
+        /// <param name="collectionName">Specifies the name of the schema to return.</param>
+        /// <param name="cancellationToken">The cancellation instruction.</param>
+        /// <returns>A task representing the asynchronous operation.</returns>
+        public virtual Task<DataTable> GetSchemaAsync(
+            string collectionName,
+            CancellationToken cancellationToken = default)
+        {
+            if (cancellationToken.IsCancellationRequested)
+            {
+                return Task.FromCanceled<DataTable>(cancellationToken);
+            }
+
+            try
+            {
+                return Task.FromResult(GetSchema(collectionName));
+            }
+            catch (Exception e)
+            {
+                return Task.FromException<DataTable>(e);
+            }
+        }
+
+        /// <summary>
+        /// This is the asynchronous version of <see cref="GetSchema(string, string[])" />.
+        /// Providers should override with an appropriate implementation.
+        /// The cancellation token can optionally be honored.
+        /// The default implementation invokes the synchronous <see cref="GetSchema(string, string[])" /> call and
+        /// returns a completed task.
+        /// The default implementation will return a cancelled task if passed an already cancelled cancellationToken.
+        /// Exceptions thrown by <see cref="GetSchema(string, string[])" /> will be communicated via the returned Task
+        /// Exception property.
+        /// </summary>
+        /// <param name="collectionName">Specifies the name of the schema to return.</param>
+        /// <param name="restrictionValues">Specifies a set of restriction values for the requested schema.</param>
+        /// <param name="cancellationToken">The cancellation instruction.</param>
+        /// <returns>A task representing the asynchronous operation.</returns>
+        public virtual Task<DataTable> GetSchemaAsync(string collectionName, string?[] restrictionValues,
+            CancellationToken cancellationToken = default)
+        {
+            if (cancellationToken.IsCancellationRequested)
+            {
+                return Task.FromCanceled<DataTable>(cancellationToken);
+            }
+
+            try
+            {
+                return Task.FromResult(GetSchema(collectionName, restrictionValues));
+            }
+            catch (Exception e)
+            {
+                return Task.FromException<DataTable>(e);
+            }
+        }
+
         protected virtual void OnStateChange(StateChangeEventArgs stateChange)
         {
             if (_suppressStateChangeForReconnection)
index 2c1e028..4753058 100644 (file)
@@ -3,6 +3,7 @@
 
 #nullable enable
 using System.Collections;
+using System.Collections.ObjectModel;
 using System.ComponentModel;
 using System.IO;
 using System.Threading.Tasks;
@@ -73,11 +74,77 @@ namespace System.Data.Common
 
         public abstract int GetOrdinal(string name);
 
+        /// <summary>
+        /// Returns a <see cref="DataTable" /> that describes the column metadata of the ><see cref="DbDataReader" />.
+        /// </summary>
+        /// <returns>A <see cref="DataTable" /> that describes the column metadata.</returns>
+        /// <exception cref="InvalidOperationException">The <see cref="DbDataReader" /> is closed.</exception>
+        /// <exception cref="IndexOutOfRangeException">The column index is out of range.</exception>
+        /// <exception cref="NotSupportedException">.NET Core only: This member is not supported.</exception>
         public virtual DataTable GetSchemaTable()
         {
             throw new NotSupportedException();
         }
 
+        /// <summary>
+        /// This is the asynchronous version of <see cref="GetSchemaTable()" />.
+        /// Providers should override with an appropriate implementation.
+        /// The cancellation token can optionally be honored.
+        /// The default implementation invokes the synchronous <see cref="GetSchemaTable()" /> call and
+        /// returns a completed task.
+        /// The default implementation will return a cancelled task if passed an already cancelled cancellationToken.
+        /// Exceptions thrown by <see cref="GetSchemaTable()" /> will be communicated via the returned Task
+        /// Exception property.
+        /// </summary>
+        /// <param name="cancellationToken">The cancellation instruction.</param>
+        /// <returns>A task representing the asynchronous operation.</returns>
+        public virtual Task<DataTable> GetSchemaTableAsync(CancellationToken cancellationToken = default)
+        {
+            if (cancellationToken.IsCancellationRequested)
+            {
+                return Task.FromCanceled<DataTable>(cancellationToken);
+            }
+
+            try
+            {
+                return Task.FromResult(GetSchemaTable());
+            }
+            catch (Exception e)
+            {
+                return Task.FromException<DataTable>(e);
+            }
+        }
+
+        /// <summary>
+        /// This is the asynchronous version of <see cref="DbDataReaderExtensions.GetColumnSchema(DbDataReader)" />.
+        /// Providers should override with an appropriate implementation.
+        /// The cancellation token can optionally be honored.
+        /// The default implementation invokes the synchronous
+        /// <see cref="DbDataReaderExtensions.GetColumnSchema(DbDataReader)" /> call and returns a completed task.
+        /// The default implementation will return a cancelled task if passed an already cancelled cancellationToken.
+        /// Exceptions thrown by <see cref="DbDataReaderExtensions.GetColumnSchema(DbDataReader)" /> will be
+        /// communicated via the returned Task Exception property.
+        /// </summary>
+        /// <param name="cancellationToken">The cancellation instruction.</param>
+        /// <returns>A task representing the asynchronous operation.</returns>
+        public virtual Task<ReadOnlyCollection<DbColumn>> GetColumnSchemaAsync(
+            CancellationToken cancellationToken = default)
+        {
+            if (cancellationToken.IsCancellationRequested)
+            {
+                return Task.FromCanceled<ReadOnlyCollection<DbColumn>>(cancellationToken);
+            }
+
+            try
+            {
+                return Task.FromResult(this.GetColumnSchema());
+            }
+            catch (Exception e)
+            {
+                return Task.FromException<ReadOnlyCollection<DbColumn>>(e);
+            }
+        }
+
         public abstract bool GetBoolean(int ordinal);
 
         public abstract byte GetByte(int ordinal);
index 89840ac..6a6fd29 100644 (file)
@@ -1,7 +1,12 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+#nullable enable
+
+using System.Diagnostics.CodeAnalysis;
 using System.Reflection;
+using System.Threading;
+using System.Threading.Tasks;
 using Xunit;
 
 namespace System.Data.Common.Tests
@@ -10,124 +15,70 @@ namespace System.Data.Common.Tests
     {
         private static volatile bool _wasFinalized;
 
-        private class FinalizingConnection : DbConnection
+        private class MockDbConnection : DbConnection
         {
-            public static void CreateAndRelease()
+            [AllowNull]
+            public override string ConnectionString
             {
-                new FinalizingConnection();
+                get => throw new NotImplementedException();
+                set => throw new NotImplementedException();
             }
 
+            public override string Database => throw new NotImplementedException();
+            public override string DataSource => throw new NotImplementedException();
+            public override string ServerVersion => throw new NotImplementedException();
+            public override ConnectionState State => throw new NotImplementedException();
+            public override void ChangeDatabase(string databaseName) => throw new NotImplementedException();
+            public override void Close() => throw new NotImplementedException();
+            public override void Open() => throw new NotImplementedException();
+            protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw new NotImplementedException();
+            protected override DbCommand CreateDbCommand() => throw new NotImplementedException();
+        }
+
+        private class FinalizingConnection : MockDbConnection
+        {
+            public static void CreateAndRelease() => new FinalizingConnection();
+
             protected override void Dispose(bool disposing)
             {
                 if (!disposing)
                     _wasFinalized = true;
                 base.Dispose(disposing);
             }
-
-            public override string ConnectionString
-            {
-                get
-                {
-                    throw new NotImplementedException();
-                }
-
-                set
-                {
-                    throw new NotImplementedException();
-                }
-            }
-
-            public override string Database
-            {
-                get
-                {
-                    throw new NotImplementedException();
-                }
-            }
-
-            public override string DataSource
-            {
-                get
-                {
-                    throw new NotImplementedException();
-                }
-            }
-
-            public override string ServerVersion
-            {
-                get
-                {
-                    throw new NotImplementedException();
-                }
-            }
-
-            public override ConnectionState State
-            {
-                get
-                {
-                    throw new NotImplementedException();
-                }
-            }
-
-            public override void ChangeDatabase(string databaseName)
-            {
-                throw new NotImplementedException();
-            }
-
-            public override void Close()
-            {
-                throw new NotImplementedException();
-            }
-
-            public override void Open()
-            {
-                throw new NotImplementedException();
-            }
-
-            protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel)
-            {
-                throw new NotImplementedException();
-            }
-
-            protected override DbCommand CreateDbCommand()
-            {
-                throw new NotImplementedException();
-            }
         }
 
-        private class DbProviderFactoryConnection : DbConnection
+        private class GetSchemaConnection : MockDbConnection
         {
-            protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel)
+            public override DataTable GetSchema()
             {
-                throw new NotImplementedException();
+                var table = new DataTable();
+                table.Columns.Add(new DataColumn("CollectionName", typeof(string)));
+                table.Columns.Add(new DataColumn("WithRestrictions", typeof(bool)));
+                table.Rows.Add("Default", false);
+                return table;
             }
 
-            public override void ChangeDatabase(string databaseName)
+            public override DataTable GetSchema(string collectionName)
             {
-                throw new NotImplementedException();
+                var table = new DataTable();
+                table.Columns.Add(new DataColumn("CollectionName", typeof(string)));
+                table.Columns.Add(new DataColumn("WithRestrictions", typeof(bool)));
+                table.Rows.Add(collectionName, false);
+                return table;
             }
 
-            public override void Close()
+            public override DataTable GetSchema(string collectionName, string?[] restrictionValues)
             {
-                throw new NotImplementedException();
-            }
-
-            public override void Open()
-            {
-                throw new NotImplementedException();
-            }
-
-            public override string ConnectionString { get; set; }
-            public override string Database { get; }
-            public override ConnectionState State { get; }
-            public override string DataSource { get; }
-            public override string ServerVersion { get; }
-
-            protected override DbCommand CreateDbCommand()
-            {
-                throw new NotImplementedException();
+                var table = new DataTable();
+                table.Columns.Add(new DataColumn("CollectionName", typeof(string)));
+                table.Columns.Add(new DataColumn("WithRestrictions", typeof(bool)));
+                table.Rows.Add(collectionName, true);
+                return table;
             }
+        }
 
+        private class DbProviderFactoryConnection : MockDbConnection
+        {
             protected override DbProviderFactory DbProviderFactory => TestDbProviderFactory.Instance;
         }
 
@@ -150,12 +101,48 @@ namespace System.Data.Common.Tests
         public void ProviderFactoryTest()
         {
             DbProviderFactoryConnection con = new DbProviderFactoryConnection();
-            PropertyInfo providerFactoryProperty = con.GetType().GetProperty("ProviderFactory", BindingFlags.NonPublic | BindingFlags.Instance);
+            PropertyInfo providerFactoryProperty = con.GetType().GetProperty("ProviderFactory", BindingFlags.NonPublic | BindingFlags.Instance)!;
             Assert.NotNull(providerFactoryProperty);
-            DbProviderFactory factory = providerFactoryProperty.GetValue(con) as DbProviderFactory;
+            DbProviderFactory? factory = providerFactoryProperty.GetValue(con) as DbProviderFactory;
             Assert.NotNull(factory);
-            Assert.Same(typeof(TestDbProviderFactory), factory.GetType());
+            Assert.Same(typeof(TestDbProviderFactory), factory!.GetType());
             Assert.Same(TestDbProviderFactory.Instance, factory);
         }
+
+        [Fact]
+        public void GetSchemaAsync_with_cancelled_token()
+        {
+            var conn = new MockDbConnection();
+            Assert.ThrowsAsync<TaskCanceledException>(async () => await conn.GetSchemaAsync(new CancellationToken(true)));
+            Assert.ThrowsAsync<TaskCanceledException>(async () => await conn.GetSchemaAsync("MetaDataCollections", new CancellationToken(true)));
+            Assert.ThrowsAsync<TaskCanceledException>(async () => await conn.GetSchemaAsync("MetaDataCollections", new string[0], new CancellationToken(true)));
+        }
+
+        [Fact]
+        public void GetSchemaAsync_with_exception()
+        {
+            var conn = new MockDbConnection();
+            Assert.ThrowsAsync<NotSupportedException>(async () => await conn.GetSchemaAsync());
+            Assert.ThrowsAsync<NotSupportedException>(async () => await conn.GetSchemaAsync("MetaDataCollections"));
+            Assert.ThrowsAsync<NotSupportedException>(async () => await conn.GetSchemaAsync("MetaDataCollections", new string[0]));
+        }
+
+        [Fact]
+        public async Task GetSchemaAsync_calls_GetSchema()
+        {
+            var conn = new GetSchemaConnection();
+
+            var row = (await conn.GetSchemaAsync()).Rows[0];
+            Assert.Equal("Default", row["CollectionName"]);
+            Assert.Equal(false, row["WithRestrictions"]);
+
+            row = (await conn.GetSchemaAsync("MetaDataCollections")).Rows[0];
+            Assert.Equal("MetaDataCollections", row["CollectionName"]);
+            Assert.Equal(false, row["WithRestrictions"]);
+
+            row = (await conn.GetSchemaAsync("MetaDataCollections", new string?[0])).Rows[0];
+            Assert.Equal("MetaDataCollections", row["CollectionName"]);
+            Assert.Equal(true, row["WithRestrictions"]);
+        }
     }
 }
index 00981a8..0735e90 100644 (file)
 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 //
 
+#nullable enable
 
+using System.Collections;
+using System.Linq;
 using System.Data.Common;
 
-namespace System.Data.Tests.Common
+namespace System.Data.Common.Tests
 {
     internal class DbDataReaderMock : DbDataReader
     {
-        private int _currentRowIndex = -1;
-        private DataTable _testDataTable;
+        protected int _currentRowIndex = -1;
+        protected DataTable _testDataTable;
 
         public DbDataReaderMock()
-        {
-            _testDataTable = new DataTable();
-        }
+            => _testDataTable = new DataTable();
 
         public DbDataReaderMock(DataTable testData)
-        {
-            _testDataTable = testData ?? throw new ArgumentNullException(nameof(testData));
-        }
-
-        public override void Close()
-        {
-            _testDataTable.Clear();
-        }
-
-        public override int Depth
-        {
-            get { throw new NotImplementedException(); }
-        }
-
-        public override int FieldCount
-        {
-            get { throw new NotImplementedException(); }
-        }
-
-        public override bool GetBoolean(int ordinal)
-        {
-            return (bool)GetValue(ordinal);
-        }
+            => _testDataTable = testData ?? throw new ArgumentNullException(nameof(testData));
 
-        public override byte GetByte(int ordinal)
-        {
-            return (byte)GetValue(ordinal);
-        }
+        public override void Close() => _testDataTable.Clear();
+        public override int Depth => throw new NotImplementedException();
+        public override int FieldCount => throw new NotImplementedException();
+        public override bool GetBoolean(int ordinal) => (bool)GetValue(ordinal);
+        public override byte GetByte(int ordinal) => (byte)GetValue(ordinal);
 
-        public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length)
+        public override long GetBytes(int ordinal, long dataOffset, byte[]? buffer, int bufferOffset, int length)
         {
             object value = GetValue(ordinal);
             if (value == DBNull.Value)
@@ -78,17 +58,19 @@ namespace System.Data.Tests.Common
             }
 
             byte[] data = (byte[])value;
+            if (buffer is null)
+            {
+                return data.Length;
+            }
+
             long bytesToRead = Math.Min(data.Length - dataOffset, length);
             Buffer.BlockCopy(data, (int)dataOffset, buffer, bufferOffset, (int)bytesToRead);
             return bytesToRead;
         }
 
-        public override char GetChar(int ordinal)
-        {
-            return (char)GetValue(ordinal);
-        }
+        public override char GetChar(int ordinal) => (char)GetValue(ordinal);
 
-        public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length)
+        public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length)
         {
             object value = GetValue(ordinal);
             if (value == DBNull.Value)
@@ -96,71 +78,28 @@ namespace System.Data.Tests.Common
                 return 0;
             }
 
-            char[] data = value.ToString().ToCharArray();
+            char[] data = value.ToString()!.ToCharArray();
+            if (buffer is null)
+            {
+                return data.Length;
+            }
             long bytesToRead = Math.Min(data.Length - dataOffset, length);
             Array.Copy(data, dataOffset, buffer, bufferOffset, bytesToRead);
             return bytesToRead;
         }
 
-        public override string GetDataTypeName(int ordinal)
-        {
-            throw new NotImplementedException();
-        }
-
-        public override DateTime GetDateTime(int ordinal)
-        {
-            return (DateTime)GetValue(ordinal);
-        }
-
-        public override decimal GetDecimal(int ordinal)
-        {
-            return (decimal)GetValue(ordinal);
-        }
-
-        public override double GetDouble(int ordinal)
-        {
-            return (double)GetValue(ordinal);
-        }
-
-        public override global::System.Collections.IEnumerator GetEnumerator()
-        {
-            throw new NotImplementedException();
-        }
-
-        public override Type GetFieldType(int ordinal)
-        {
-            throw new NotImplementedException();
-        }
-
-        public override float GetFloat(int ordinal)
-        {
-            return (float)GetValue(ordinal);
-        }
-
-        public override Guid GetGuid(int ordinal)
-        {
-            return (Guid)GetValue(ordinal);
-        }
-
-        public override short GetInt16(int ordinal)
-        {
-            return (short)GetValue(ordinal);
-        }
-
-        public override int GetInt32(int ordinal)
-        {
-            return (int)GetValue(ordinal);
-        }
-
-        public override long GetInt64(int ordinal)
-        {
-            return (long)GetValue(ordinal);
-        }
-
-        public override string GetName(int ordinal)
-        {
-            return _testDataTable.Columns[ordinal].ColumnName;
-        }
+        public override string GetDataTypeName(int ordinal) => throw new NotImplementedException();
+        public override DateTime GetDateTime(int ordinal) => (DateTime)GetValue(ordinal);
+        public override decimal GetDecimal(int ordinal) => (decimal)GetValue(ordinal);
+        public override double GetDouble(int ordinal) => (double)GetValue(ordinal);
+        public override IEnumerator GetEnumerator() => throw new NotImplementedException();
+        public override Type GetFieldType(int ordinal) => throw new NotImplementedException();
+        public override float GetFloat(int ordinal) => (float)GetValue(ordinal);
+        public override Guid GetGuid(int ordinal) => (Guid)GetValue(ordinal);
+        public override short GetInt16(int ordinal) => (short)GetValue(ordinal);
+        public override int GetInt32(int ordinal) => (int)GetValue(ordinal);
+        public override long GetInt64(int ordinal) => (long)GetValue(ordinal);
+        public override string GetName(int ordinal) => _testDataTable.Columns[ordinal].ColumnName;
 
         public override int GetOrdinal(string name)
         {
@@ -178,45 +117,13 @@ namespace System.Data.Tests.Common
             return -1;
         }
 
-        public override DataTable GetSchemaTable()
-        {
-            throw new NotImplementedException();
-        }
-
-        public override string GetString(int ordinal)
-        {
-            return (string)_testDataTable.Rows[_currentRowIndex][ordinal];
-        }
-
-        public override object GetValue(int ordinal)
-        {
-            return _testDataTable.Rows[_currentRowIndex][ordinal];
-        }
-
-        public override int GetValues(object[] values)
-        {
-            throw new NotImplementedException();
-        }
-
-        public override bool HasRows
-        {
-            get { throw new NotImplementedException(); }
-        }
-
-        public override bool IsClosed
-        {
-            get { throw new NotImplementedException(); }
-        }
-
-        public override bool IsDBNull(int ordinal)
-        {
-            return _testDataTable.Rows[_currentRowIndex][ordinal] == DBNull.Value;
-        }
-
-        public override bool NextResult()
-        {
-            throw new NotImplementedException();
-        }
+        public override string GetString(int ordinal) => (string)_testDataTable.Rows[_currentRowIndex][ordinal];
+        public override object GetValue(int ordinal) => _testDataTable.Rows[_currentRowIndex][ordinal];
+        public override int GetValues(object[] values) => throw new NotImplementedException();
+        public override bool HasRows => throw new NotImplementedException();
+        public override bool IsClosed => throw new NotImplementedException();
+        public override bool IsDBNull(int ordinal) => _testDataTable.Rows[_currentRowIndex][ordinal] == DBNull.Value;
+        public override bool NextResult() => throw new NotImplementedException();
 
         public override bool Read()
         {
@@ -224,19 +131,30 @@ namespace System.Data.Tests.Common
             return _currentRowIndex < _testDataTable.Rows.Count;
         }
 
-        public override int RecordsAffected
-        {
-            get { throw new NotImplementedException(); }
-        }
+        public override int RecordsAffected => throw new NotImplementedException();
+        public override object this[string name] => throw new NotImplementedException();
+        public override object this[int ordinal] => throw new NotImplementedException();
+    }
 
-        public override object this[string name]
-        {
-            get { throw new NotImplementedException(); }
-        }
+    internal class SchemaDbDataReaderMock : DbDataReaderMock
+    {
+        public SchemaDbDataReaderMock(DataTable testData) : base(testData) {}
 
-        public override object this[int ordinal]
+        public override DataTable GetSchemaTable()
         {
-            get { throw new NotImplementedException(); }
+            var table = new DataTable("SchemaTable");
+            table.Columns.Add("ColumnName", typeof(string));
+            table.Columns.Add("DataType", typeof(Type));
+
+            foreach (var column in _testDataTable.Columns.Cast<DataColumn>())
+            {
+                var row = table.NewRow();
+                row["ColumnName"] = column.ColumnName;
+                row["DataType"] = column.DataType;
+                table.Rows.Add(row);
+            }
+
+            return table;
         }
     }
 }
index 04e54ce..be9460d 100644 (file)
 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
+#nullable enable
+
+using System.Linq;
 using System.IO;
 using System.Threading;
 using System.Threading.Tasks;
 using Xunit;
 
-namespace System.Data.Tests.Common
+namespace System.Data.Common.Tests
 {
     public class DbDataReaderTest
     {
@@ -499,6 +502,46 @@ namespace System.Data.Tests.Common
             return Assert.ThrowsAsync<TaskCanceledException>(() => _dataReader.IsDBNullAsync("dbnull_col", new CancellationToken(true)));
         }
 
+        [Fact]
+        public void GetSchemaTableAsync_with_cancelled_token()
+            => Assert.ThrowsAsync<TaskCanceledException>(async () => await new DbDataReaderMock().GetSchemaTableAsync(new CancellationToken(true)));
+
+        [Fact]
+        public void GetSchemaTableAsync_with_exception()
+            => Assert.ThrowsAsync<NotSupportedException>(async () => await new DbDataReaderMock().GetSchemaTableAsync());
+
+        [Fact]
+        public async Task GetSchemaTableAsync_calls_GetSchemaTable()
+        {
+            var readerTable = new DataTable();
+            readerTable.Columns.Add("text_col", typeof(string));
+
+            var table = await new SchemaDbDataReaderMock(readerTable).GetSchemaTableAsync();
+
+            var textColRow = table.Rows.Cast<DataRow>().Single();
+            Assert.Equal("text_col", textColRow["ColumnName"]);
+            Assert.Same(typeof(string), textColRow["DataType"]);
+        }
+
+        [Fact]
+        public void GetColumnSchemaAsync_with_cancelled_token()
+            => Assert.ThrowsAsync<TaskCanceledException>(async () => await new DbDataReaderMock().GetColumnSchemaAsync(new CancellationToken(true)));
+
+        [Fact]
+        public void GetColumnSchemaAsync_with_exception()
+            => Assert.ThrowsAsync<NotSupportedException>(async () => await new DbDataReaderMock().GetColumnSchemaAsync());
+
+        [Fact]
+        public async Task GetColumnSchemaAsync_calls_GetSchemaTable()
+        {
+            var readerTable = new DataTable();
+            readerTable.Columns.Add("text_col", typeof(string));
+
+            var column = (await new SchemaDbDataReaderMock(readerTable).GetColumnSchemaAsync()).Single();
+            Assert.Equal("text_col", column.ColumnName);
+            Assert.Same(typeof(string), column.DataType);
+        }
+
         private void SkipRows(int rowsToSkip)
         {
             var i = 0;