From 3eac0693fab69a4b02f7d61dd474c37798f65608 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Thu, 29 Nov 2018 11:57:36 -0800 Subject: [PATCH] Add more advanced SafeArray tests (dotnet/coreclr#21063) * Port simple SafeArray tests over from .NET Framework's test suite. * Clean up native calling conventions. * Add basic field array -> field SAFEARRAY* tests. * Fix some hiding stdcall calling convention decorators. * PR Feedback. * Rewrite native part of SafeArray tests * Wrote the managed side of the tests. Non-passing ones are currently commented out. * Fix up tests and enable marshalling a string[] as a SAFEARRAY of LPSTR or LPWSTR. * Add blittable records test with a return parameter. * Add returning non-blittable record safearray test. * Fix which typedef we're using * Clean up * Fix/cleanup RecordNative * Replace "" with <> for platformdefines.h include. * Use macro for calling convention. * PR Feedback. * Remove residual "Helpers.h" include. * Fix remaining incorrect calling conventions. * One more calling convention fix. Commit migrated from https://github.com/dotnet/coreclr/commit/93d3f4c1f7f8a27c94aad68837dc8dda30c19b02 --- src/coreclr/src/vm/olevariant.cpp | 4 + .../tests/src/Common/Platform/platformdefines.cpp | 15 +- .../tests/src/Common/Platform/platformdefines.h | 5 +- .../ForeignThreadExceptionsNative.cpp | 2 +- .../ByValArray/MarshalArrayByValNative.cpp | 2 +- .../ArrayMarshalling/SafeArray/CMakeLists.txt | 11 + .../ArrayMarshalling/SafeArray/RecordNative.cpp | 400 +++++++++++++++++++++ .../ArrayMarshalling/SafeArray/SafeArrayNative.cpp | 313 ++++++++++++++++ .../ArrayMarshalling/SafeArray/SafeArrayTest.cs | 219 +++++++++++ .../SafeArray/SafeArrayTest.csproj | 43 +++ src/coreclr/tests/src/Interop/CMakeLists.txt | 1 + .../FuncPtrAsDelegateParamNative.cpp | 2 +- .../Interop/SimpleStruct/SimpleStructNative.cpp | 2 +- .../StringMarshalling/BSTR/BSTRTestNative.cpp | 2 +- .../PInvoke/MarshalStructAsParamDLL.h | 2 +- .../MarshalExpStruct/ExpStructAsParamNative.h | 2 +- .../SeqStructDelRevPInvokeNative.cpp | 2 +- 17 files changed, 1015 insertions(+), 12 deletions(-) create mode 100644 src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/CMakeLists.txt create mode 100644 src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/RecordNative.cpp create mode 100644 src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayNative.cpp create mode 100644 src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.cs create mode 100644 src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.csproj diff --git a/src/coreclr/src/vm/olevariant.cpp b/src/coreclr/src/vm/olevariant.cpp index 4188a22..b874fc4 100644 --- a/src/coreclr/src/vm/olevariant.cpp +++ b/src/coreclr/src/vm/olevariant.cpp @@ -472,6 +472,10 @@ BOOL OleVariant::IsValidArrayForSafeArrayElementType(BASEARRAYREF *pArrayRef, VA case VT_CY: return vtActual == VT_DECIMAL; + case VT_LPSTR: + case VT_LPWSTR: + return vtActual == VT_BSTR; + default: return FALSE; } diff --git a/src/coreclr/tests/src/Common/Platform/platformdefines.cpp b/src/coreclr/tests/src/Common/Platform/platformdefines.cpp index 61b91b3..7e21ab0 100644 --- a/src/coreclr/tests/src/Common/Platform/platformdefines.cpp +++ b/src/coreclr/tests/src/Common/Platform/platformdefines.cpp @@ -438,7 +438,7 @@ HRESULT CbSysStringSize(ULONG cchSize, BOOL isByteLen, ULONG *result) return INTSAFE_E_ARITHMETIC_OVERFLOW; } -BSTR TP_SysAllocString(LPWSTR psz) +BSTR TP_SysAllocString(LPCWSTR psz) { #ifdef WINDOWS return SysAllocString(psz); @@ -449,7 +449,7 @@ BSTR TP_SysAllocString(LPWSTR psz) #endif } -BSTR TP_SysAllocStringLen(LPWSTR psz, size_t len) +BSTR TP_SysAllocStringLen(LPCWSTR psz, size_t len) { ULONG cbTotal = 0; @@ -542,3 +542,14 @@ size_t TP_SysStringByteLen(BSTR bstr) return (unsigned int)(((DWORD *)bstr)[-1]); #endif } + +DWORD TP_SysStringLen(BSTR bstr) +{ +#ifdef WINDOWS + return SysStringLen(bstr); +#else + if(bstr == NULL) + return 0; + return (unsigned int)((((DWORD *)bstr)[-1]) / sizeof(OLECHAR)); +#endif +} diff --git a/src/coreclr/tests/src/Common/Platform/platformdefines.h b/src/coreclr/tests/src/Common/Platform/platformdefines.h index e1e84b6..0a33bc1 100644 --- a/src/coreclr/tests/src/Common/Platform/platformdefines.h +++ b/src/coreclr/tests/src/Common/Platform/platformdefines.h @@ -146,8 +146,9 @@ typedef WCHAR* BSTR; BSTR TP_SysAllocStringByteLen(LPCSTR psz, size_t len); void TP_SysFreeString(BSTR bstr); size_t TP_SysStringByteLen(BSTR bstr); -BSTR TP_SysAllocStringLen(LPWSTR psz, size_t len); -BSTR TP_SysAllocString(LPWSTR psz); +BSTR TP_SysAllocStringLen(LPCWSTR psz, size_t len); +BSTR TP_SysAllocString(LPCWSTR psz); +DWORD TP_SysStringLen(BSTR bstr); // // Method redirects diff --git a/src/coreclr/tests/src/Exceptions/ForeignThread/ForeignThreadExceptionsNative.cpp b/src/coreclr/tests/src/Exceptions/ForeignThread/ForeignThreadExceptionsNative.cpp index b640597..67a92e3 100644 --- a/src/coreclr/tests/src/Exceptions/ForeignThread/ForeignThreadExceptionsNative.cpp +++ b/src/coreclr/tests/src/Exceptions/ForeignThread/ForeignThreadExceptionsNative.cpp @@ -12,7 +12,7 @@ // Work around typedef redefinition: platformdefines.h defines error_t // as unsigned while it's defined as int in errno.h. #define error_t error_t_ignore -#include "platformdefines.h" +#include #undef error_t typedef void (*PFNACTION1)(); diff --git a/src/coreclr/tests/src/Interop/ArrayMarshalling/ByValArray/MarshalArrayByValNative.cpp b/src/coreclr/tests/src/Interop/ArrayMarshalling/ByValArray/MarshalArrayByValNative.cpp index 9e6be4f..86d726e 100644 --- a/src/coreclr/tests/src/Interop/ArrayMarshalling/ByValArray/MarshalArrayByValNative.cpp +++ b/src/coreclr/tests/src/Interop/ArrayMarshalling/ByValArray/MarshalArrayByValNative.cpp @@ -3,7 +3,7 @@ // See the LICENSE file in the project root for more information. #include -#include "platformdefines.h" +#include const int ARRAY_SIZE = 100; template bool IsObjectEquals(T o1, T o2); diff --git a/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/CMakeLists.txt b/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/CMakeLists.txt new file mode 100644 index 0000000..e5ff0f68 --- /dev/null +++ b/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required (VERSION 2.6) +project (SafeArrayNative) +include_directories(${INC_PLATFORM_DIR}) +set(SOURCES SafeArrayNative.cpp RecordNative.cpp) + +# add the executable +add_library (SafeArrayNative SHARED ${SOURCES}) +target_link_libraries(SafeArrayNative ${LINK_LIBRARIES_ADDITIONAL}) + +# add the install targets +install (TARGETS SafeArrayNative DESTINATION bin) diff --git a/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/RecordNative.cpp b/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/RecordNative.cpp new file mode 100644 index 0000000..2f37965 --- /dev/null +++ b/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/RecordNative.cpp @@ -0,0 +1,400 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#include +#include +#include +#include + +struct BlittableRecord +{ + int a; +}; + +class BlittableRecordInfo : public IRecordInfo +{ +public: + HRESULT STDMETHODCALLTYPE GetField(PVOID pvData, LPCOLESTR szFieldName, VARIANT* pvarField) + { + if (pvData == nullptr || pvarField == nullptr) + { + return E_INVALIDARG; + } + + BlittableRecord* pData = (BlittableRecord*)pvData; + + if (wcscmp(szFieldName, W("a")) == 0) + { + VariantClear(pvarField); + V_VT(pvarField) = VT_I4; + V_I4(pvarField) = pData->a; + return S_OK; + } + return E_INVALIDARG; + } + + HRESULT STDMETHODCALLTYPE GetFieldNames(ULONG* pcNames, BSTR* rgBstrNames) + { + if (pcNames == nullptr) + { + return E_INVALIDARG; + } + if (rgBstrNames == nullptr) + { + *pcNames = 1; + return S_OK; + } + + if (*pcNames == 0) + { + return S_OK; + } + + rgBstrNames[0] = TP_SysAllocString(W("a")); + + for(size_t i = 1; i < *pcNames; i++) + { + rgBstrNames[i] = nullptr; + } + + return S_OK; + } + + HRESULT STDMETHODCALLTYPE GetFieldNoCopy( + PVOID pvData, + LPCOLESTR szFieldName, + VARIANT *pvarField, + PVOID *ppvDataCArray + ) + { + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE GetGuid(GUID *pguid) + { + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE GetName(BSTR* pbstrName) + { + *pbstrName = TP_SysAllocString(W("BlittableRecord")); + return S_OK; + } + + HRESULT STDMETHODCALLTYPE GetSize(ULONG* pcbSize) + { + *pcbSize = sizeof(BlittableRecord); + return S_OK; + } + + HRESULT STDMETHODCALLTYPE GetTypeInfo(ITypeInfo** ppTypeInfo) + { + return TYPE_E_INVALIDSTATE; + } + + BOOL STDMETHODCALLTYPE IsMatchingType(IRecordInfo* pRecordInfo) + { + return pRecordInfo == this; + } + + HRESULT STDMETHODCALLTYPE PutField( + ULONG wFlags, + PVOID pvData, + LPCOLESTR szFieldName, + VARIANT *pvarField + ) + { + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE PutFieldNoCopy( + ULONG wFlags, + PVOID pvData, + LPCOLESTR szFieldName, + VARIANT *pvarField + ) + { + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE RecordClear(PVOID pvExisting) + { + return S_OK; + } + + HRESULT STDMETHODCALLTYPE RecordCopy(PVOID pvExisting, PVOID pvNew) + { + ((BlittableRecord*)pvNew)->a = ((BlittableRecord*)pvExisting)->a; + return S_OK; + } + + PVOID STDMETHODCALLTYPE RecordCreate() + { + return CoreClrAlloc(sizeof(BlittableRecord)); + } + + HRESULT STDMETHODCALLTYPE RecordCreateCopy( + PVOID pvSource, + PVOID *ppvDest + ) + { + *ppvDest = RecordCreate(); + return RecordCopy(pvSource, *ppvDest); + } + + HRESULT STDMETHODCALLTYPE RecordDestroy(PVOID pvRecord) + { + CoreClrFree(pvRecord); + return S_OK; + } + + HRESULT STDMETHODCALLTYPE RecordInit(PVOID pvNew) + { + ((BlittableRecord*)pvNew)->a = 0; + return S_OK; + } + + ULONG STDMETHODCALLTYPE AddRef() + { + return ++refCount; + } + + ULONG STDMETHODCALLTYPE Release() + { + return --refCount; + } + + HRESULT STDMETHODCALLTYPE QueryInterface(const IID& riid, void** ppvObject) + { + if (riid == __uuidof(IRecordInfo)) + { + *ppvObject = static_cast(this); + } + else if (riid == __uuidof(IUnknown)) + { + *ppvObject = static_cast(this); + } + else + { + *ppvObject = nullptr; + return E_NOINTERFACE; + } + + AddRef(); + return S_OK; + } + +private: + ULONG refCount; +} s_BlittableRecordInfo; + +struct NonBlittableRecord +{ + BOOL b; +}; + + +class NonBlittableRecordInfo : public IRecordInfo +{ +public: + HRESULT STDMETHODCALLTYPE GetField(PVOID pvData, LPCOLESTR szFieldName, VARIANT* pvarField) + { + if (pvData == nullptr || pvarField == nullptr) + { + return E_INVALIDARG; + } + + NonBlittableRecord* pData = (NonBlittableRecord*)pvData; + + if (wcscmp(szFieldName, W("b")) == 0) + { + VariantClear(pvarField); + V_VT(pvarField) = VT_BOOL; + V_BOOL(pvarField) = pData->b == TRUE ? VARIANT_TRUE : VARIANT_FALSE; + return S_OK; + } + return E_INVALIDARG; + } + + HRESULT STDMETHODCALLTYPE GetFieldNames(ULONG* pcNames, BSTR* rgBstrNames) + { + if (pcNames == nullptr) + { + return E_INVALIDARG; + } + if (rgBstrNames == nullptr) + { + *pcNames = 1; + return S_OK; + } + + if (*pcNames == 0) + { + return S_OK; + } + + rgBstrNames[0] = TP_SysAllocString(W("b")); + + for(size_t i = 1; i < *pcNames; i++) + { + rgBstrNames[i] = nullptr; + } + + return S_OK; + } + + HRESULT STDMETHODCALLTYPE GetFieldNoCopy( + PVOID pvData, + LPCOLESTR szFieldName, + VARIANT *pvarField, + PVOID *ppvDataCArray + ) + { + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE GetGuid(GUID *pguid) + { + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE GetName(BSTR* pbstrName) + { + *pbstrName = TP_SysAllocString(W("NonBlittableRecord")); + return S_OK; + } + + HRESULT STDMETHODCALLTYPE GetSize(ULONG* pcbSize) + { + *pcbSize = sizeof(BlittableRecord); + return S_OK; + } + + HRESULT STDMETHODCALLTYPE GetTypeInfo(ITypeInfo** ppTypeInfo) + { + return TYPE_E_INVALIDSTATE; + } + + BOOL STDMETHODCALLTYPE IsMatchingType(IRecordInfo* pRecordInfo) + { + return pRecordInfo == this; + } + + HRESULT STDMETHODCALLTYPE PutField( + ULONG wFlags, + PVOID pvData, + LPCOLESTR szFieldName, + VARIANT *pvarField + ) + { + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE PutFieldNoCopy( + ULONG wFlags, + PVOID pvData, + LPCOLESTR szFieldName, + VARIANT *pvarField + ) + { + return E_FAIL; + } + + HRESULT STDMETHODCALLTYPE RecordClear(PVOID pvExisting) + { + return S_OK; + } + + HRESULT STDMETHODCALLTYPE RecordCopy(PVOID pvExisting, PVOID pvNew) + { + ((NonBlittableRecord*)pvNew)->b = ((NonBlittableRecord*)pvExisting)->b; + return S_OK; + } + + PVOID STDMETHODCALLTYPE RecordCreate() + { + return CoreClrAlloc(sizeof(NonBlittableRecord)); + } + + HRESULT STDMETHODCALLTYPE RecordCreateCopy( + PVOID pvSource, + PVOID *ppvDest + ) + { + *ppvDest = RecordCreate(); + return RecordCopy(pvSource, *ppvDest); + } + + HRESULT STDMETHODCALLTYPE RecordDestroy(PVOID pvRecord) + { + CoreClrFree(pvRecord); + return S_OK; + } + + HRESULT STDMETHODCALLTYPE RecordInit(PVOID pvNew) + { + ((NonBlittableRecord*)pvNew)->b = FALSE; + return S_OK; + } + + ULONG STDMETHODCALLTYPE AddRef() + { + return ++refCount; + } + + ULONG STDMETHODCALLTYPE Release() + { + return --refCount; + } + + HRESULT STDMETHODCALLTYPE QueryInterface(const IID& riid, void** ppvObject) + { + if (riid == __uuidof(IRecordInfo)) + { + *ppvObject = static_cast(this); + } + else if (riid == __uuidof(IUnknown)) + { + *ppvObject = static_cast(this); + } + else + { + *ppvObject = nullptr; + return E_NOINTERFACE; + } + + AddRef(); + return S_OK; + } + +private: + ULONG refCount; +} s_NonBlittableRecordInfo; + +extern "C" DLL_EXPORT SAFEARRAY* STDMETHODCALLTYPE CreateSafeArrayOfRecords(BlittableRecord records[], int numRecords) +{ + SAFEARRAYBOUND bounds[1] = { + {numRecords, 0} + }; + + SAFEARRAY* arr = SafeArrayCreateEx(VT_RECORD, 1, bounds, &s_BlittableRecordInfo); + + memcpy(arr->pvData, records, numRecords * sizeof(BlittableRecord)); + + return arr; +} + + +extern "C" DLL_EXPORT SAFEARRAY* STDMETHODCALLTYPE CreateSafeArrayOfNonBlittableRecords(NonBlittableRecord records[], int numRecords) +{ + SAFEARRAYBOUND bounds[1] = { + {numRecords, 0} + }; + + SAFEARRAY* arr = SafeArrayCreateEx(VT_RECORD, 1, bounds, &s_NonBlittableRecordInfo); + + memcpy(arr->pvData, records, numRecords * sizeof(NonBlittableRecord)); + + return arr; +} diff --git a/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayNative.cpp b/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayNative.cpp new file mode 100644 index 0000000..76f4b54 --- /dev/null +++ b/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayNative.cpp @@ -0,0 +1,313 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#include +#include +#include +#include + +#define RETURN_IF_FAILED(x) if(FAILED(hr = (x))) { return hr; } + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE XorBoolArray(SAFEARRAY* d, BOOL* result) +{ + HRESULT hr; + *result = FALSE; + VARTYPE elementType; + RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType)); + + if (elementType != VT_BOOL) + { + return E_INVALIDARG; + } + + LONG lowerBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex)); + LONG upperBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex)); + + VARIANT_BOOL* values; + RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values)); + + for(long i = lowerBoundIndex; i <= upperBoundIndex; i++) + { + *result ^= values[i] == VARIANT_TRUE ? TRUE : FALSE; + } + + RETURN_IF_FAILED(::SafeArrayUnaccessData(d)); + + return S_OK; +} + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE MeanDecimalArray(SAFEARRAY* d, DECIMAL* result) +{ + HRESULT hr; + DECIMAL sum{}; + DECIMAL_SETZERO(sum); + + VARTYPE elementType; + RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType)); + + if (elementType != VT_DECIMAL) + { + return E_INVALIDARG; + } + + LONG lowerBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex)); + LONG upperBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex)); + + DECIMAL* values; + RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values)); + + for(long i = lowerBoundIndex; i <= upperBoundIndex; i++) + { + DECIMAL lhs = sum; + VarDecAdd(&lhs, &values[i], &sum); + } + + DECIMAL numElements; + VarDecFromI4(upperBoundIndex - lowerBoundIndex + 1, &numElements); + + VarDecDiv(&sum, &numElements, result); + + RETURN_IF_FAILED(::SafeArrayUnaccessData(d)); + + return S_OK; +} + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE SumCurrencyArray(SAFEARRAY* d, CY* result) +{ + HRESULT hr; + CY sum{}; + VARTYPE elementType; + RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType)); + + if (elementType != VT_CY) + { + return E_INVALIDARG; + } + + LONG lowerBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex)); + LONG upperBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex)); + + CY* values; + RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values)); + + for(long i = lowerBoundIndex; i <= upperBoundIndex; i++) + { + CY lhs = sum; + VarCyAdd(lhs, values[i], &sum); + } + + *result = sum; + + RETURN_IF_FAILED(::SafeArrayUnaccessData(d)); + + return S_OK; +} + +template +StringType ReverseInplace(size_t len, StringType s) +{ + std::reverse(s, s + len); + return s; +} + +template +HRESULT Reverse(StringType str, StringType *res) +{ + StringType tmp = str; + size_t len = 0; + while (*tmp++) + ++len; + + size_t strDataLen = (len + 1) * sizeof(str[0]); + auto resLocal = (StringType)CoreClrAlloc(strDataLen); + if (resLocal == nullptr) + return E_INVALIDARG; + + memcpy(resLocal, str, strDataLen); + *res = ReverseInplace(len, resLocal); + + return S_OK; +} + +HRESULT ReverseBSTR(BSTR str, BSTR *res) +{ + size_t strDataLen = TP_SysStringByteLen(str); + BSTR resLocal = TP_SysAllocStringByteLen(reinterpret_cast(str), strDataLen); + if (resLocal == nullptr) + return E_INVALIDARG; + + UINT len = TP_SysStringLen(str); + *res = ReverseInplace(len, resLocal); + + return S_OK; +} + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE ReverseStrings(SAFEARRAY* d) +{ + HRESULT hr; + VARTYPE elementType; + RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType)); + + if (elementType != VT_LPSTR && elementType != VT_LPWSTR && elementType != VT_BSTR) + { + return E_INVALIDARG; + } + + LONG lowerBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex)); + LONG upperBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex)); + + void** values; + RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values)); + + for(long i = lowerBoundIndex; i <= upperBoundIndex; i++) + { + if (elementType == VT_LPSTR) + { + LPSTR reversed; + RETURN_IF_FAILED(Reverse((LPSTR)values[i], &reversed)); + values[i] = reversed; + } + else if (elementType == VT_LPWSTR) + { + LPWSTR reversed; + RETURN_IF_FAILED(Reverse((LPWSTR)values[i], &reversed)); + values[i] = reversed; + } + else if (elementType == VT_BSTR) + { + BSTR reversed; + RETURN_IF_FAILED(ReverseBSTR((BSTR)values[i], &reversed)); + values[i] = reversed; + } + } + + RETURN_IF_FAILED(::SafeArrayUnaccessData(d)); + + return S_OK; +} + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE VerifyInterfaceArray(SAFEARRAY* d, VARTYPE expectedType) +{ + HRESULT hr; + VARTYPE elementType; + RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType)); + + if (elementType != expectedType) + { + return E_INVALIDARG; + } + + LONG lowerBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex)); + LONG upperBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex)); + + IUnknown** values; + RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values)); + + for(long i = lowerBoundIndex; i <= upperBoundIndex; i++) + { + values[i]->AddRef(); + values[i]->Release(); + } + + RETURN_IF_FAILED(::SafeArrayUnaccessData(d)); + + return S_OK; +} + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE MeanVariantIntArray(SAFEARRAY* d, int* result) +{ + HRESULT hr; + *result = 0; + VARTYPE elementType; + RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType)); + + if (elementType != VT_VARIANT) + { + return E_INVALIDARG; + } + + LONG lowerBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex)); + LONG upperBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex)); + + VARIANT* values; + RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values)); + + for(long i = lowerBoundIndex; i <= upperBoundIndex; i++) + { + if (values[i].vt != VT_I4) + { + RETURN_IF_FAILED(::SafeArrayUnaccessData(d)); + return E_INVALIDARG; + } + + *result += values[i].intVal; + } + + *result /= upperBoundIndex - lowerBoundIndex + 1; + + RETURN_IF_FAILED(::SafeArrayUnaccessData(d)); + + return S_OK; +} + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE DistanceBetweenDates(SAFEARRAY* d, double* result) +{ + HRESULT hr; + *result = 0; + VARTYPE elementType; + RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType)); + + if (elementType != VT_DATE) + { + return E_INVALIDARG; + } + + LONG lowerBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex)); + LONG upperBoundIndex; + RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex)); + + DATE* values; + RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values)); + + bool haveLastValue = false; + DATE lastValue; + + for(long i = lowerBoundIndex; i <= upperBoundIndex; i++) + { + if (haveLastValue) + { + *result += values[i] - lastValue; + } + + lastValue = values[i]; + haveLastValue = true; + } + + RETURN_IF_FAILED(::SafeArrayUnaccessData(d)); + + return S_OK; +} + +struct StructWithSafeArray +{ + SAFEARRAY* array; +}; + +extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE XorBoolArrayInStruct(StructWithSafeArray str, BOOL* result) +{ + return XorBoolArray(str.array, result); +} diff --git a/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.cs b/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.cs new file mode 100644 index 0000000..59030d6 --- /dev/null +++ b/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.cs @@ -0,0 +1,219 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Linq; +using System.Threading; +using System.Runtime.InteropServices; +using TestLibrary; + +#pragma warning disable CS0612, CS0618 + +public class Tester +{ + public static int Main() + { + try + { + var boolArray = new bool[] { true, false, true, false, false, true }; + SafeArrayNative.XorBoolArray(boolArray, out var xorResult); + Assert.AreEqual(XorArray(boolArray), xorResult); + + var decimalArray = new decimal[] { 1.5M, 30.2M, 6432M, 12.5832M }; + SafeArrayNative.MeanDecimalArray(decimalArray, out var meanDecimalValue); + Assert.AreEqual(decimalArray.Average(), meanDecimalValue); + + SafeArrayNative.SumCurrencyArray(decimalArray, out var sumCurrencyValue); + Assert.AreEqual(decimalArray.Sum(), sumCurrencyValue); + + var strings = new [] {"ABCDE", "12345", "Microsoft"}; + var reversedStrings = strings.Select(str => Reverse(str)).ToArray(); + + var ansiTest = strings.ToArray(); + SafeArrayNative.ReverseStringsAnsi(ansiTest); + Assert.AreAllEqual(reversedStrings, ansiTest); + + var unicodeTest = strings.ToArray(); + SafeArrayNative.ReverseStringsUnicode(unicodeTest); + Assert.AreAllEqual(reversedStrings, unicodeTest); + + var bstrTest = strings.ToArray(); + SafeArrayNative.ReverseStringsBSTR(bstrTest); + Assert.AreAllEqual(reversedStrings, bstrTest); + + var blittableRecords = new SafeArrayNative.BlittableRecord[] + { + new SafeArrayNative.BlittableRecord { a = 1 }, + new SafeArrayNative.BlittableRecord { a = 5 }, + new SafeArrayNative.BlittableRecord { a = 7 }, + new SafeArrayNative.BlittableRecord { a = 3 }, + new SafeArrayNative.BlittableRecord { a = 9 }, + new SafeArrayNative.BlittableRecord { a = 15 }, + }; + Assert.AreAllEqual(blittableRecords, SafeArrayNative.CreateSafeArrayOfRecords(blittableRecords)); + + var nonBlittableRecords = boolArray.Select(b => new SafeArrayNative.NonBlittableRecord{ b = b }).ToArray(); + Assert.AreAllEqual(nonBlittableRecords, SafeArrayNative.CreateSafeArrayOfRecords(nonBlittableRecords)); + + var objects = new object[] { new object(), new object(), new object() }; + SafeArrayNative.VerifyIUnknownArray(objects); + SafeArrayNative.VerifyIDispatchArray(objects); + + var variantInts = new object[] {1, 2, 3, 4, 5, 6, 7, 8, 9}; + + SafeArrayNative.MeanVariantIntArray(variantInts, out var variantMean); + Assert.AreEqual(variantInts.OfType().Average(), variantMean); + + var dates = new DateTime[] { new DateTime(2008, 5, 1), new DateTime(2010, 1, 1) }; + SafeArrayNative.DistanceBetweenDates(dates, out var numDays); + Assert.AreEqual((dates[1] - dates[0]).TotalDays, numDays); + + SafeArrayNative.XorBoolArrayInStruct( + new SafeArrayNative.StructWithSafeArray + { + values = boolArray + }, + out var structXor); + + Assert.AreEqual(XorArray(boolArray), structXor); + } + catch (Exception e) + { + Console.WriteLine(e); + return 101; + } + return 100; + } + + private static bool XorArray(bool[] values) + { + bool retVal = false; + foreach (var item in values) + { + retVal ^= item; + } + return retVal; + } + + private static string Reverse(string s) + { + var chars = s.ToCharArray(); + Array.Reverse(chars); + return new string(chars); + } +} + +class SafeArrayNative +{ + public struct StructWithSafeArray + { + [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_BOOL)] + public bool[] values; + } + + public struct BlittableRecord + { + public int a; + } + + public struct NonBlittableRecord + { + public bool b; + } + + [DllImport(nameof(SafeArrayNative))] + [return: MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_RECORD)] + private static extern BlittableRecord[] CreateSafeArrayOfRecords( + BlittableRecord[] records, + int numElements + ); + + public static BlittableRecord[] CreateSafeArrayOfRecords(BlittableRecord[] records) + { + return CreateSafeArrayOfRecords(records, records.Length); + } + + [DllImport(nameof(SafeArrayNative), EntryPoint = "CreateSafeArrayOfNonBlittableRecords")] + [return: MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_RECORD)] + private static extern NonBlittableRecord[] CreateSafeArrayOfRecords( + NonBlittableRecord[] records, + int numElements + ); + + public static NonBlittableRecord[] CreateSafeArrayOfRecords(NonBlittableRecord[] records) + { + return CreateSafeArrayOfRecords(records, records.Length); + } + + [DllImport(nameof(SafeArrayNative), PreserveSig = false)] + public static extern void XorBoolArray( + [MarshalAs(UnmanagedType.SafeArray)] bool[] values, + out bool result + ); + + [DllImport(nameof(SafeArrayNative), PreserveSig = false)] + public static extern void MeanDecimalArray( + [MarshalAs(UnmanagedType.SafeArray)] decimal[] values, + out decimal result + ); + + [DllImport(nameof(SafeArrayNative), PreserveSig = false)] + public static extern void SumCurrencyArray( + [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_CY)] decimal[] values, + [MarshalAs(UnmanagedType.Currency)] out decimal result + ); + + [DllImport(nameof(SafeArrayNative), PreserveSig = false, EntryPoint = "ReverseStrings")] + public static extern void ReverseStringsAnsi( + [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_LPSTR), In, Out] string[] strings + ); + + [DllImport(nameof(SafeArrayNative), PreserveSig = false, EntryPoint = "ReverseStrings")] + public static extern void ReverseStringsUnicode( + [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_LPWSTR), In, Out] string[] strings + ); + + [DllImport(nameof(SafeArrayNative), PreserveSig = false, EntryPoint = "ReverseStrings")] + public static extern void ReverseStringsBSTR( + [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_BSTR), In, Out] string[] strings + ); + + [DllImport(nameof(SafeArrayNative), PreserveSig = false, EntryPoint = "VerifyInterfaceArray")] + private static extern void VerifyInterfaceArrayIUnknown( + [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_UNKNOWN)] object[] objects, + short expectedVarType + ); + + [DllImport(nameof(SafeArrayNative), PreserveSig = false, EntryPoint = "VerifyInterfaceArray")] + private static extern void VerifyInterfaceArrayIDispatch( + [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_DISPATCH)] object[] objects, + short expectedVarType + ); + + public static void VerifyIUnknownArray(object[] objects) + { + VerifyInterfaceArrayIUnknown(objects, (short)VarEnum.VT_UNKNOWN); + } + + public static void VerifyIDispatchArray(object[] objects) + { + VerifyInterfaceArrayIDispatch(objects, (short)VarEnum.VT_DISPATCH); + } + + [DllImport(nameof(SafeArrayNative), PreserveSig = false)] + public static extern void MeanVariantIntArray( + [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_VARIANT)] + object[] objects, + out int result + ); + + [DllImport(nameof(SafeArrayNative), PreserveSig = false)] + public static extern void DistanceBetweenDates( + [MarshalAs(UnmanagedType.SafeArray)] DateTime[] dates, + out double result + ); + + [DllImport(nameof(SafeArrayNative), PreserveSig = false)] + public static extern void XorBoolArrayInStruct(StructWithSafeArray str, out bool result); +} diff --git a/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.csproj b/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.csproj new file mode 100644 index 0000000..8234f78 --- /dev/null +++ b/src/coreclr/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.csproj @@ -0,0 +1,43 @@ + + + + + Debug + AnyCPU + SafeArrayTest + 2.0 + {F1E66554-8C8E-4141-85CF-D0CD6A0CD0B0} + Exe + {786C830F-07A1-408B-BD7F-6EE04809D6DB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC} + ..\..\ + + $(DefineConstants);STATIC + + true + true + + + + + + + + + False + + + + + + + + + + + {c8c0dc74-fac4-45b1-81fe-70c4808366e0} + CoreCLRTestLibrary + + + + + diff --git a/src/coreclr/tests/src/Interop/CMakeLists.txt b/src/coreclr/tests/src/Interop/CMakeLists.txt index a1a55e7..54e20a8 100644 --- a/src/coreclr/tests/src/Interop/CMakeLists.txt +++ b/src/coreclr/tests/src/Interop/CMakeLists.txt @@ -71,6 +71,7 @@ if(WIN32) add_subdirectory(PInvoke/NativeCallManagedComVisible) # This test doesn't necessarily need to be Windows-only, but the implementation is very tied to Windows APIs add_subdirectory(PInvoke/DateTime) + add_subdirectory(ArrayMarshalling/SafeArray) add_subdirectory(COM/NativeServer) add_subdirectory(COM/NativeClients/Primitives) add_subdirectory(IJW/FakeMscoree) diff --git a/src/coreclr/tests/src/Interop/FuncPtrAsDelegateParam/FuncPtrAsDelegateParamNative.cpp b/src/coreclr/tests/src/Interop/FuncPtrAsDelegateParam/FuncPtrAsDelegateParamNative.cpp index 72617d1..02575b0 100644 --- a/src/coreclr/tests/src/Interop/FuncPtrAsDelegateParam/FuncPtrAsDelegateParamNative.cpp +++ b/src/coreclr/tests/src/Interop/FuncPtrAsDelegateParam/FuncPtrAsDelegateParamNative.cpp @@ -1,4 +1,4 @@ -#include "platformdefines.h" +#include #include #include #include diff --git a/src/coreclr/tests/src/Interop/SimpleStruct/SimpleStructNative.cpp b/src/coreclr/tests/src/Interop/SimpleStruct/SimpleStructNative.cpp index 73e9b91..63219ad 100644 --- a/src/coreclr/tests/src/Interop/SimpleStruct/SimpleStructNative.cpp +++ b/src/coreclr/tests/src/Interop/SimpleStruct/SimpleStructNative.cpp @@ -1,4 +1,4 @@ -#include "platformdefines.h" +#include #include #include diff --git a/src/coreclr/tests/src/Interop/StringMarshalling/BSTR/BSTRTestNative.cpp b/src/coreclr/tests/src/Interop/StringMarshalling/BSTR/BSTRTestNative.cpp index 26f78c7..0365167 100644 --- a/src/coreclr/tests/src/Interop/StringMarshalling/BSTR/BSTRTestNative.cpp +++ b/src/coreclr/tests/src/Interop/StringMarshalling/BSTR/BSTRTestNative.cpp @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -#include "platformdefines.h" +#include WCHAR strManaged[] = W("Managed\0String\0"); size_t lenstrManaged = sizeof(strManaged) - sizeof(WCHAR); diff --git a/src/coreclr/tests/src/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.h b/src/coreclr/tests/src/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.h index befe246..6cb379f 100644 --- a/src/coreclr/tests/src/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.h +++ b/src/coreclr/tests/src/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.h @@ -1,4 +1,4 @@ -#include "platformdefines.h" +#include #include inline char* CoStrDup(const char* str) diff --git a/src/coreclr/tests/src/Interop/StructMarshalling/ReversePInvoke/MarshalExpStruct/ExpStructAsParamNative.h b/src/coreclr/tests/src/Interop/StructMarshalling/ReversePInvoke/MarshalExpStruct/ExpStructAsParamNative.h index cde3b52..36eb913 100644 --- a/src/coreclr/tests/src/Interop/StructMarshalling/ReversePInvoke/MarshalExpStruct/ExpStructAsParamNative.h +++ b/src/coreclr/tests/src/Interop/StructMarshalling/ReversePInvoke/MarshalExpStruct/ExpStructAsParamNative.h @@ -6,7 +6,7 @@ #include #include #include -#include "platformdefines.h" +#include const int NumArrElements = 2; struct InnerSequential diff --git a/src/coreclr/tests/src/Interop/StructMarshalling/ReversePInvoke/MarshalSeqStruct/SeqStructDelRevPInvokeNative.cpp b/src/coreclr/tests/src/Interop/StructMarshalling/ReversePInvoke/MarshalSeqStruct/SeqStructDelRevPInvokeNative.cpp index 08f069c..28779a7 100644 --- a/src/coreclr/tests/src/Interop/StructMarshalling/ReversePInvoke/MarshalSeqStruct/SeqStructDelRevPInvokeNative.cpp +++ b/src/coreclr/tests/src/Interop/StructMarshalling/ReversePInvoke/MarshalSeqStruct/SeqStructDelRevPInvokeNative.cpp @@ -3,7 +3,7 @@ // See the LICENSE file in the project root for more information. #include -#include "platformdefines.h" +#include #include "SeqStructDelRevPInvokeNative.h" const char* NativeStr = "Native"; -- 2.7.4