Add more advanced SafeArray tests (#21063)
authorJeremy Koritzinsky <jkoritzinsky@gmail.com>
Thu, 29 Nov 2018 19:57:36 +0000 (11:57 -0800)
committerGitHub <noreply@github.com>
Thu, 29 Nov 2018 19:57:36 +0000 (11:57 -0800)
* Port simple SafeArray tests over from .NET Framework's test suite.

* Clean up native calling conventions.

* Add basic field array -> field SAFEARRAY* tests.

* Fix some hiding stdcall calling convention decorators.

* PR Feedback.

* Rewrite native part of SafeArray tests

* Wrote the managed side of the tests. Non-passing ones are currently commented out.

* Fix up tests and enable marshalling a string[] as a SAFEARRAY of LPSTR or LPWSTR.

* Add blittable records test with a return parameter.

* Add returning non-blittable record safearray test.

* Fix which typedef we're using

* Clean up

* Fix/cleanup RecordNative

* Replace "" with <> for platformdefines.h include.

* Use macro for calling convention.

* PR Feedback.

* Remove residual "Helpers.h" include.

* Fix remaining incorrect calling conventions.

* One more calling convention fix.

17 files changed:
src/vm/olevariant.cpp
tests/src/Common/Platform/platformdefines.cpp
tests/src/Common/Platform/platformdefines.h
tests/src/Exceptions/ForeignThread/ForeignThreadExceptionsNative.cpp
tests/src/Interop/ArrayMarshalling/ByValArray/MarshalArrayByValNative.cpp
tests/src/Interop/ArrayMarshalling/SafeArray/CMakeLists.txt [new file with mode: 0644]
tests/src/Interop/ArrayMarshalling/SafeArray/RecordNative.cpp [new file with mode: 0644]
tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayNative.cpp [new file with mode: 0644]
tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.cs [new file with mode: 0644]
tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.csproj [new file with mode: 0644]
tests/src/Interop/CMakeLists.txt
tests/src/Interop/FuncPtrAsDelegateParam/FuncPtrAsDelegateParamNative.cpp
tests/src/Interop/SimpleStruct/SimpleStructNative.cpp
tests/src/Interop/StringMarshalling/BSTR/BSTRTestNative.cpp
tests/src/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.h
tests/src/Interop/StructMarshalling/ReversePInvoke/MarshalExpStruct/ExpStructAsParamNative.h
tests/src/Interop/StructMarshalling/ReversePInvoke/MarshalSeqStruct/SeqStructDelRevPInvokeNative.cpp

index 4188a22..b874fc4 100644 (file)
@@ -472,6 +472,10 @@ BOOL OleVariant::IsValidArrayForSafeArrayElementType(BASEARRAYREF *pArrayRef, VA
         case VT_CY:
             return vtActual == VT_DECIMAL;
 
+        case VT_LPSTR:
+        case VT_LPWSTR:
+            return vtActual == VT_BSTR;
+
         default:
             return FALSE;
     }
index 61b91b3..7e21ab0 100644 (file)
@@ -438,7 +438,7 @@ HRESULT CbSysStringSize(ULONG cchSize, BOOL isByteLen, ULONG *result)
     return INTSAFE_E_ARITHMETIC_OVERFLOW;
 }
 
-BSTR TP_SysAllocString(LPWSTR psz)
+BSTR TP_SysAllocString(LPCWSTR psz)
 {
 #ifdef WINDOWS    
     return SysAllocString(psz);
@@ -449,7 +449,7 @@ BSTR TP_SysAllocString(LPWSTR psz)
 #endif
 }
 
-BSTR TP_SysAllocStringLen(LPWSTR psz, size_t len)
+BSTR TP_SysAllocStringLen(LPCWSTR psz, size_t len)
 {
     ULONG cbTotal = 0;
 
@@ -542,3 +542,14 @@ size_t TP_SysStringByteLen(BSTR bstr)
     return (unsigned int)(((DWORD *)bstr)[-1]);
 #endif    
 }
+
+DWORD TP_SysStringLen(BSTR bstr)
+{
+#ifdef WINDOWS
+    return SysStringLen(bstr);
+#else
+    if(bstr == NULL)
+      return 0;
+    return (unsigned int)((((DWORD *)bstr)[-1]) / sizeof(OLECHAR));
+#endif
+}
index e1e84b6..0a33bc1 100644 (file)
@@ -146,8 +146,9 @@ typedef WCHAR* BSTR;
 BSTR TP_SysAllocStringByteLen(LPCSTR psz, size_t len);
 void TP_SysFreeString(BSTR bstr);
 size_t TP_SysStringByteLen(BSTR bstr);
-BSTR TP_SysAllocStringLen(LPWSTR psz, size_t len);
-BSTR TP_SysAllocString(LPWSTR psz);
+BSTR TP_SysAllocStringLen(LPCWSTR psz, size_t len);
+BSTR TP_SysAllocString(LPCWSTR psz);
+DWORD TP_SysStringLen(BSTR bstr);
 
 //
 // Method redirects
index b640597..67a92e3 100644 (file)
@@ -12,7 +12,7 @@
 // Work around typedef redefinition: platformdefines.h defines error_t
 // as unsigned while it's defined as int in errno.h.
 #define error_t error_t_ignore
-#include "platformdefines.h"
+#include <platformdefines.h>
 #undef error_t
 
 typedef void (*PFNACTION1)();
index 9e6be4f..86d726e 100644 (file)
@@ -3,7 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 #include <xplatform.h>
-#include "platformdefines.h"
+#include <platformdefines.h>
 const int ARRAY_SIZE = 100;
 template<typename T> bool IsObjectEquals(T o1, T o2);
 
diff --git a/tests/src/Interop/ArrayMarshalling/SafeArray/CMakeLists.txt b/tests/src/Interop/ArrayMarshalling/SafeArray/CMakeLists.txt
new file mode 100644 (file)
index 0000000..e5ff0f6
--- /dev/null
@@ -0,0 +1,11 @@
+cmake_minimum_required (VERSION 2.6)
+project (SafeArrayNative)
+include_directories(${INC_PLATFORM_DIR})
+set(SOURCES SafeArrayNative.cpp RecordNative.cpp)
+
+# add the executable
+add_library (SafeArrayNative SHARED ${SOURCES})
+target_link_libraries(SafeArrayNative ${LINK_LIBRARIES_ADDITIONAL}) 
+
+# add the install targets
+install (TARGETS SafeArrayNative DESTINATION bin)
diff --git a/tests/src/Interop/ArrayMarshalling/SafeArray/RecordNative.cpp b/tests/src/Interop/ArrayMarshalling/SafeArray/RecordNative.cpp
new file mode 100644 (file)
index 0000000..2f37965
--- /dev/null
@@ -0,0 +1,400 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#include <xplatform.h>
+#include <oleauto.h>
+#include <algorithm>
+#include <platformdefines.h>
+
+struct BlittableRecord
+{
+    int a;
+};
+
+class BlittableRecordInfo : public IRecordInfo
+{
+public:
+    HRESULT STDMETHODCALLTYPE GetField(PVOID pvData, LPCOLESTR szFieldName, VARIANT* pvarField)
+    {
+        if (pvData == nullptr || pvarField == nullptr)
+        {
+            return E_INVALIDARG;
+        }
+
+        BlittableRecord* pData = (BlittableRecord*)pvData;
+
+        if (wcscmp(szFieldName, W("a")) == 0)
+        {
+            VariantClear(pvarField);
+            V_VT(pvarField) = VT_I4;
+            V_I4(pvarField) = pData->a;
+            return S_OK;
+        }
+        return E_INVALIDARG;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetFieldNames(ULONG* pcNames, BSTR* rgBstrNames)
+    {
+        if (pcNames == nullptr)
+        {
+            return E_INVALIDARG;
+        }
+        if (rgBstrNames == nullptr)
+        {
+            *pcNames = 1;
+            return S_OK;
+        }
+
+        if (*pcNames == 0)
+        {
+            return S_OK;
+        }
+
+        rgBstrNames[0] = TP_SysAllocString(W("a"));
+        
+        for(size_t i = 1; i < *pcNames; i++)
+        {
+            rgBstrNames[i] = nullptr;
+        }
+        
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetFieldNoCopy(
+        PVOID     pvData,
+        LPCOLESTR szFieldName,
+        VARIANT   *pvarField,
+        PVOID     *ppvDataCArray
+    )
+    {
+        return E_FAIL;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetGuid(GUID *pguid)
+    {
+        return E_FAIL;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetName(BSTR* pbstrName)
+    {
+        *pbstrName = TP_SysAllocString(W("BlittableRecord"));
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetSize(ULONG* pcbSize)
+    {
+        *pcbSize = sizeof(BlittableRecord);
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetTypeInfo(ITypeInfo** ppTypeInfo)
+    {
+        return TYPE_E_INVALIDSTATE;
+    }
+
+    BOOL STDMETHODCALLTYPE IsMatchingType(IRecordInfo* pRecordInfo)
+    {
+        return pRecordInfo == this;
+    }
+
+    HRESULT STDMETHODCALLTYPE PutField(
+        ULONG     wFlags,
+        PVOID     pvData,
+        LPCOLESTR szFieldName,
+        VARIANT   *pvarField
+    )
+    {
+        return E_FAIL;
+    }
+
+    HRESULT STDMETHODCALLTYPE PutFieldNoCopy(
+        ULONG     wFlags,
+        PVOID     pvData,
+        LPCOLESTR szFieldName,
+        VARIANT   *pvarField
+    )
+    {
+        return E_FAIL;
+    }
+
+    HRESULT STDMETHODCALLTYPE RecordClear(PVOID pvExisting)
+    {
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE RecordCopy(PVOID pvExisting, PVOID pvNew)
+    {
+        ((BlittableRecord*)pvNew)->a = ((BlittableRecord*)pvExisting)->a;
+        return S_OK;
+    }
+
+    PVOID STDMETHODCALLTYPE RecordCreate()
+    {
+        return CoreClrAlloc(sizeof(BlittableRecord));
+    }
+
+    HRESULT STDMETHODCALLTYPE RecordCreateCopy(
+        PVOID pvSource,
+        PVOID *ppvDest
+    )
+    {
+        *ppvDest = RecordCreate();
+        return RecordCopy(pvSource, *ppvDest);
+    }
+
+    HRESULT STDMETHODCALLTYPE RecordDestroy(PVOID pvRecord)
+    {
+        CoreClrFree(pvRecord);
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE RecordInit(PVOID pvNew)
+    {
+        ((BlittableRecord*)pvNew)->a = 0;
+        return S_OK;
+    }
+
+    ULONG STDMETHODCALLTYPE AddRef()
+    {
+        return ++refCount;
+    }
+
+    ULONG STDMETHODCALLTYPE Release()
+    {
+        return --refCount;
+    }
+
+    HRESULT STDMETHODCALLTYPE QueryInterface(const IID& riid, void** ppvObject)
+    {
+        if (riid == __uuidof(IRecordInfo))
+        {
+            *ppvObject = static_cast<IRecordInfo*>(this);
+        }
+        else if (riid == __uuidof(IUnknown))
+        {
+            *ppvObject = static_cast<IUnknown*>(this);
+        }
+        else
+        {
+            *ppvObject = nullptr;
+            return E_NOINTERFACE;
+        }
+
+        AddRef();
+        return S_OK;
+    }
+
+private:
+    ULONG refCount;
+} s_BlittableRecordInfo;
+
+struct NonBlittableRecord
+{
+    BOOL b;
+};
+
+
+class NonBlittableRecordInfo : public IRecordInfo
+{
+public:
+    HRESULT STDMETHODCALLTYPE GetField(PVOID pvData, LPCOLESTR szFieldName, VARIANT* pvarField)
+    {
+        if (pvData == nullptr || pvarField == nullptr)
+        {
+            return E_INVALIDARG;
+        }
+
+        NonBlittableRecord* pData = (NonBlittableRecord*)pvData;
+
+        if (wcscmp(szFieldName, W("b")) == 0)
+        {
+            VariantClear(pvarField);
+            V_VT(pvarField) = VT_BOOL;
+            V_BOOL(pvarField) = pData->b == TRUE ? VARIANT_TRUE : VARIANT_FALSE;
+            return S_OK;
+        }
+        return E_INVALIDARG;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetFieldNames(ULONG* pcNames, BSTR* rgBstrNames)
+    {
+        if (pcNames == nullptr)
+        {
+            return E_INVALIDARG;
+        }
+        if (rgBstrNames == nullptr)
+        {
+            *pcNames = 1;
+            return S_OK;
+        }
+
+        if (*pcNames == 0)
+        {
+            return S_OK;
+        }
+
+        rgBstrNames[0] = TP_SysAllocString(W("b"));
+        
+        for(size_t i = 1; i < *pcNames; i++)
+        {
+            rgBstrNames[i] = nullptr;
+        }
+        
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetFieldNoCopy(
+        PVOID     pvData,
+        LPCOLESTR szFieldName,
+        VARIANT   *pvarField,
+        PVOID     *ppvDataCArray
+    )
+    {
+        return E_FAIL;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetGuid(GUID *pguid)
+    {
+        return E_FAIL;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetName(BSTR* pbstrName)
+    {
+        *pbstrName = TP_SysAllocString(W("NonBlittableRecord"));
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetSize(ULONG* pcbSize)
+    {
+        *pcbSize = sizeof(BlittableRecord);
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE GetTypeInfo(ITypeInfo** ppTypeInfo)
+    {
+        return TYPE_E_INVALIDSTATE;
+    }
+
+    BOOL STDMETHODCALLTYPE IsMatchingType(IRecordInfo* pRecordInfo)
+    {
+        return pRecordInfo == this;
+    }
+
+    HRESULT STDMETHODCALLTYPE PutField(
+        ULONG     wFlags,
+        PVOID     pvData,
+        LPCOLESTR szFieldName,
+        VARIANT   *pvarField
+    )
+    {
+        return E_FAIL;
+    }
+
+    HRESULT STDMETHODCALLTYPE PutFieldNoCopy(
+        ULONG     wFlags,
+        PVOID     pvData,
+        LPCOLESTR szFieldName,
+        VARIANT   *pvarField
+    )
+    {
+        return E_FAIL;
+    }
+
+    HRESULT STDMETHODCALLTYPE RecordClear(PVOID pvExisting)
+    {
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE RecordCopy(PVOID pvExisting, PVOID pvNew)
+    {
+        ((NonBlittableRecord*)pvNew)->b = ((NonBlittableRecord*)pvExisting)->b;
+        return S_OK;
+    }
+
+    PVOID STDMETHODCALLTYPE RecordCreate()
+    {
+        return CoreClrAlloc(sizeof(NonBlittableRecord));
+    }
+
+    HRESULT STDMETHODCALLTYPE RecordCreateCopy(
+        PVOID pvSource,
+        PVOID *ppvDest
+    )
+    {
+        *ppvDest = RecordCreate();
+        return RecordCopy(pvSource, *ppvDest);
+    }
+
+    HRESULT STDMETHODCALLTYPE RecordDestroy(PVOID pvRecord)
+    {
+        CoreClrFree(pvRecord);
+        return S_OK;
+    }
+
+    HRESULT STDMETHODCALLTYPE RecordInit(PVOID pvNew)
+    {
+        ((NonBlittableRecord*)pvNew)->b = FALSE;
+        return S_OK;
+    }
+
+    ULONG STDMETHODCALLTYPE AddRef()
+    {
+        return ++refCount;
+    }
+
+    ULONG STDMETHODCALLTYPE Release()
+    {
+        return --refCount;
+    }
+
+    HRESULT STDMETHODCALLTYPE QueryInterface(const IID& riid, void** ppvObject)
+    {
+        if (riid == __uuidof(IRecordInfo))
+        {
+            *ppvObject = static_cast<IRecordInfo*>(this);
+        }
+        else if (riid == __uuidof(IUnknown))
+        {
+            *ppvObject = static_cast<IUnknown*>(this);
+        }
+        else
+        {
+            *ppvObject = nullptr;
+            return E_NOINTERFACE;
+        }
+
+        AddRef();
+        return S_OK;
+    }
+
+private:
+    ULONG refCount;
+} s_NonBlittableRecordInfo;
+
+extern "C" DLL_EXPORT SAFEARRAY* STDMETHODCALLTYPE CreateSafeArrayOfRecords(BlittableRecord records[], int numRecords)
+{
+    SAFEARRAYBOUND bounds[1] = {
+        {numRecords, 0}
+    };
+
+    SAFEARRAY* arr = SafeArrayCreateEx(VT_RECORD, 1, bounds, &s_BlittableRecordInfo);
+
+    memcpy(arr->pvData, records, numRecords * sizeof(BlittableRecord));
+
+    return arr;
+}
+
+
+extern "C" DLL_EXPORT SAFEARRAY* STDMETHODCALLTYPE CreateSafeArrayOfNonBlittableRecords(NonBlittableRecord records[], int numRecords)
+{
+    SAFEARRAYBOUND bounds[1] = {
+        {numRecords, 0}
+    };
+
+    SAFEARRAY* arr = SafeArrayCreateEx(VT_RECORD, 1, bounds, &s_NonBlittableRecordInfo);
+
+    memcpy(arr->pvData, records, numRecords * sizeof(NonBlittableRecord));
+
+    return arr;
+}
diff --git a/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayNative.cpp b/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayNative.cpp
new file mode 100644 (file)
index 0000000..76f4b54
--- /dev/null
@@ -0,0 +1,313 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#include <xplatform.h>
+#include <oleauto.h>
+#include <algorithm>
+#include <platformdefines.h>
+
+#define RETURN_IF_FAILED(x) if(FAILED(hr = (x))) { return hr; }
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE XorBoolArray(SAFEARRAY* d, BOOL* result)
+{
+    HRESULT hr;
+    *result = FALSE;
+    VARTYPE elementType;
+    RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType));
+
+    if (elementType != VT_BOOL)
+    {
+        return E_INVALIDARG;
+    }
+
+    LONG lowerBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex));
+    LONG upperBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex));
+
+    VARIANT_BOOL* values;
+    RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values));
+    
+    for(long i = lowerBoundIndex; i <= upperBoundIndex; i++)
+    {
+        *result ^= values[i] == VARIANT_TRUE ? TRUE : FALSE;
+    }
+
+    RETURN_IF_FAILED(::SafeArrayUnaccessData(d));
+    
+    return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE MeanDecimalArray(SAFEARRAY* d, DECIMAL* result)
+{
+    HRESULT hr;
+    DECIMAL sum{};
+    DECIMAL_SETZERO(sum);
+
+    VARTYPE elementType;
+    RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType));
+
+    if (elementType != VT_DECIMAL)
+    {
+        return E_INVALIDARG;
+    }
+
+    LONG lowerBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex));
+    LONG upperBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex));
+
+    DECIMAL* values;
+    RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values));
+    
+    for(long i = lowerBoundIndex; i <= upperBoundIndex; i++)
+    {
+        DECIMAL lhs = sum;
+        VarDecAdd(&lhs, &values[i], &sum);
+    }
+
+    DECIMAL numElements;
+    VarDecFromI4(upperBoundIndex - lowerBoundIndex + 1, &numElements);
+
+    VarDecDiv(&sum, &numElements, result);
+
+    RETURN_IF_FAILED(::SafeArrayUnaccessData(d));
+    
+    return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE SumCurrencyArray(SAFEARRAY* d, CY* result)
+{
+    HRESULT hr;
+    CY sum{};
+    VARTYPE elementType;
+    RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType));
+
+    if (elementType != VT_CY)
+    {
+        return E_INVALIDARG;
+    }
+
+    LONG lowerBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex));
+    LONG upperBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex));
+
+    CY* values;
+    RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values));
+    
+    for(long i = lowerBoundIndex; i <= upperBoundIndex; i++)
+    {
+        CY lhs = sum;
+        VarCyAdd(lhs, values[i], &sum);
+    }
+
+    *result = sum;
+
+    RETURN_IF_FAILED(::SafeArrayUnaccessData(d));
+    
+    return S_OK;
+}
+
+template <typename StringType>
+StringType ReverseInplace(size_t len, StringType s)
+{
+    std::reverse(s, s + len);
+    return s;
+}
+
+template<typename StringType>
+HRESULT Reverse(StringType str, StringType *res)
+{
+    StringType tmp = str;
+    size_t len = 0;
+    while (*tmp++)
+        ++len;
+
+    size_t strDataLen = (len + 1) * sizeof(str[0]);
+    auto resLocal = (StringType)CoreClrAlloc(strDataLen);
+    if (resLocal == nullptr)
+        return E_INVALIDARG;
+
+    memcpy(resLocal, str, strDataLen);
+    *res = ReverseInplace(len, resLocal);
+
+    return S_OK;
+}
+
+HRESULT ReverseBSTR(BSTR str, BSTR *res)
+{
+    size_t strDataLen = TP_SysStringByteLen(str);
+    BSTR resLocal = TP_SysAllocStringByteLen(reinterpret_cast<LPCSTR>(str), strDataLen);
+    if (resLocal == nullptr)
+        return E_INVALIDARG;
+
+    UINT len = TP_SysStringLen(str);
+    *res = ReverseInplace(len, resLocal);
+
+    return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE ReverseStrings(SAFEARRAY* d)
+{
+    HRESULT hr;
+    VARTYPE elementType;
+    RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType));
+
+    if (elementType != VT_LPSTR && elementType != VT_LPWSTR && elementType != VT_BSTR)
+    {
+        return E_INVALIDARG;
+    }
+
+    LONG lowerBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex));
+    LONG upperBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex));
+
+    void** values;
+    RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values));
+    
+    for(long i = lowerBoundIndex; i <= upperBoundIndex; i++)
+    {
+        if (elementType == VT_LPSTR)
+        {
+            LPSTR reversed;
+            RETURN_IF_FAILED(Reverse((LPSTR)values[i], &reversed));
+            values[i] = reversed;
+        }
+        else if (elementType == VT_LPWSTR)
+        {
+            LPWSTR reversed;
+            RETURN_IF_FAILED(Reverse((LPWSTR)values[i], &reversed));
+            values[i] = reversed;
+        }
+        else if (elementType == VT_BSTR)
+        {
+            BSTR reversed;
+            RETURN_IF_FAILED(ReverseBSTR((BSTR)values[i], &reversed));
+            values[i] = reversed;
+        }
+    }
+
+    RETURN_IF_FAILED(::SafeArrayUnaccessData(d));
+    
+    return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE VerifyInterfaceArray(SAFEARRAY* d, VARTYPE expectedType)
+{
+    HRESULT hr;
+    VARTYPE elementType;
+    RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType));
+
+    if (elementType != expectedType)
+    {
+        return E_INVALIDARG;
+    }
+
+    LONG lowerBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex));
+    LONG upperBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex));
+
+    IUnknown** values;
+    RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values));
+    
+    for(long i = lowerBoundIndex; i <= upperBoundIndex; i++)
+    {
+        values[i]->AddRef();
+        values[i]->Release();
+    }
+
+    RETURN_IF_FAILED(::SafeArrayUnaccessData(d));
+    
+    return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE MeanVariantIntArray(SAFEARRAY* d, int* result)
+{
+    HRESULT hr;
+    *result = 0;
+    VARTYPE elementType;
+    RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType));
+
+    if (elementType != VT_VARIANT)
+    {
+        return E_INVALIDARG;
+    }
+
+    LONG lowerBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex));
+    LONG upperBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex));
+
+    VARIANT* values;
+    RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values));
+    
+    for(long i = lowerBoundIndex; i <= upperBoundIndex; i++)
+    {
+        if (values[i].vt != VT_I4)
+        {
+            RETURN_IF_FAILED(::SafeArrayUnaccessData(d));
+            return E_INVALIDARG;
+        }
+        
+        *result += values[i].intVal;
+    }
+
+    *result /= upperBoundIndex - lowerBoundIndex + 1;
+
+    RETURN_IF_FAILED(::SafeArrayUnaccessData(d));
+    
+    return S_OK;
+}
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE DistanceBetweenDates(SAFEARRAY* d, double* result)
+{
+    HRESULT hr;
+    *result = 0;
+    VARTYPE elementType;
+    RETURN_IF_FAILED(::SafeArrayGetVartype(d, &elementType));
+
+    if (elementType != VT_DATE)
+    {
+        return E_INVALIDARG;
+    }
+
+    LONG lowerBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetLBound(d, 1, &lowerBoundIndex));
+    LONG upperBoundIndex;
+    RETURN_IF_FAILED(::SafeArrayGetUBound(d, 1, &upperBoundIndex));
+
+    DATE* values;
+    RETURN_IF_FAILED(::SafeArrayAccessData(d, (void**)&values));
+    
+    bool haveLastValue = false;
+    DATE lastValue;
+
+    for(long i = lowerBoundIndex; i <= upperBoundIndex; i++)
+    {
+        if (haveLastValue)
+        {
+            *result += values[i] - lastValue;
+        }
+
+        lastValue = values[i];
+        haveLastValue = true;
+    }
+
+    RETURN_IF_FAILED(::SafeArrayUnaccessData(d));
+    
+    return S_OK;
+}
+
+struct StructWithSafeArray
+{
+    SAFEARRAY* array;
+};
+
+extern "C" DLL_EXPORT HRESULT STDMETHODCALLTYPE XorBoolArrayInStruct(StructWithSafeArray str, BOOL* result)
+{
+    return XorBoolArray(str.array, result);
+}
diff --git a/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.cs b/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.cs
new file mode 100644 (file)
index 0000000..59030d6
--- /dev/null
@@ -0,0 +1,219 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Linq;
+using System.Threading;
+using System.Runtime.InteropServices;
+using TestLibrary;
+
+#pragma warning disable CS0612, CS0618
+
+public class Tester
+{
+    public static int Main()
+    {
+        try
+        {
+            var boolArray = new bool[] { true, false, true, false, false, true };
+            SafeArrayNative.XorBoolArray(boolArray, out var xorResult);
+            Assert.AreEqual(XorArray(boolArray), xorResult);
+
+            var decimalArray = new decimal[] { 1.5M, 30.2M, 6432M, 12.5832M };
+            SafeArrayNative.MeanDecimalArray(decimalArray, out var meanDecimalValue);
+            Assert.AreEqual(decimalArray.Average(), meanDecimalValue);
+
+            SafeArrayNative.SumCurrencyArray(decimalArray, out var sumCurrencyValue);
+            Assert.AreEqual(decimalArray.Sum(), sumCurrencyValue);
+
+            var strings = new [] {"ABCDE", "12345", "Microsoft"};
+            var reversedStrings = strings.Select(str => Reverse(str)).ToArray();
+
+            var ansiTest = strings.ToArray();
+            SafeArrayNative.ReverseStringsAnsi(ansiTest);
+            Assert.AreAllEqual(reversedStrings, ansiTest);
+
+            var unicodeTest = strings.ToArray();
+            SafeArrayNative.ReverseStringsUnicode(unicodeTest);
+            Assert.AreAllEqual(reversedStrings, unicodeTest);
+
+            var bstrTest = strings.ToArray();
+            SafeArrayNative.ReverseStringsBSTR(bstrTest);
+            Assert.AreAllEqual(reversedStrings, bstrTest);
+
+            var blittableRecords = new SafeArrayNative.BlittableRecord[]
+            {
+                new SafeArrayNative.BlittableRecord { a = 1 },
+                new SafeArrayNative.BlittableRecord { a = 5 },
+                new SafeArrayNative.BlittableRecord { a = 7 },
+                new SafeArrayNative.BlittableRecord { a = 3 },
+                new SafeArrayNative.BlittableRecord { a = 9 },
+                new SafeArrayNative.BlittableRecord { a = 15 },
+            };
+            Assert.AreAllEqual(blittableRecords, SafeArrayNative.CreateSafeArrayOfRecords(blittableRecords));
+
+            var nonBlittableRecords = boolArray.Select(b => new SafeArrayNative.NonBlittableRecord{ b = b }).ToArray();
+            Assert.AreAllEqual(nonBlittableRecords, SafeArrayNative.CreateSafeArrayOfRecords(nonBlittableRecords));
+
+            var objects = new object[] { new object(), new object(), new object() };
+            SafeArrayNative.VerifyIUnknownArray(objects);
+            SafeArrayNative.VerifyIDispatchArray(objects);
+
+            var variantInts = new object[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
+            
+            SafeArrayNative.MeanVariantIntArray(variantInts, out var variantMean);
+            Assert.AreEqual(variantInts.OfType<int>().Average(), variantMean);
+
+            var dates = new DateTime[] { new DateTime(2008, 5, 1), new DateTime(2010, 1, 1) };
+            SafeArrayNative.DistanceBetweenDates(dates, out var numDays);
+            Assert.AreEqual((dates[1] - dates[0]).TotalDays, numDays);
+
+            SafeArrayNative.XorBoolArrayInStruct(
+                new SafeArrayNative.StructWithSafeArray
+                {
+                    values = boolArray
+                },
+                out var structXor);
+
+            Assert.AreEqual(XorArray(boolArray), structXor);
+        }
+        catch (Exception e)
+        {
+            Console.WriteLine(e);
+            return 101;
+        }
+        return 100;
+    }
+
+    private static bool XorArray(bool[] values)
+    {
+        bool retVal = false;
+        foreach (var item in values)
+        {
+            retVal ^= item;
+        }
+        return retVal;
+    }
+
+    private static string Reverse(string s)
+    {
+        var chars = s.ToCharArray();
+        Array.Reverse(chars);
+        return new string(chars);
+    }
+}
+
+class SafeArrayNative
+{
+    public struct StructWithSafeArray
+    {
+        [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_BOOL)]
+        public bool[] values;
+    }
+
+    public struct BlittableRecord
+    {
+        public int a;
+    }
+
+    public struct NonBlittableRecord
+    {
+        public bool b;
+    }
+
+    [DllImport(nameof(SafeArrayNative))]
+    [return: MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_RECORD)]
+    private static extern  BlittableRecord[] CreateSafeArrayOfRecords(
+        BlittableRecord[] records,
+        int numElements
+    );
+
+    public static BlittableRecord[] CreateSafeArrayOfRecords(BlittableRecord[] records)
+    {
+        return CreateSafeArrayOfRecords(records, records.Length);
+    }
+
+    [DllImport(nameof(SafeArrayNative), EntryPoint = "CreateSafeArrayOfNonBlittableRecords")]
+    [return: MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_RECORD)]
+    private static extern NonBlittableRecord[] CreateSafeArrayOfRecords(
+        NonBlittableRecord[] records,
+        int numElements
+    );
+
+    public static NonBlittableRecord[] CreateSafeArrayOfRecords(NonBlittableRecord[] records)
+    {
+        return CreateSafeArrayOfRecords(records, records.Length);
+    }
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false)]
+    public static extern void XorBoolArray(
+        [MarshalAs(UnmanagedType.SafeArray)] bool[] values,
+        out bool result
+    );
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false)]
+    public static extern void MeanDecimalArray(
+        [MarshalAs(UnmanagedType.SafeArray)] decimal[] values,
+        out decimal result
+    );
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false)]
+    public static extern void SumCurrencyArray(
+        [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_CY)] decimal[] values,
+        [MarshalAs(UnmanagedType.Currency)] out decimal result
+    );
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false, EntryPoint = "ReverseStrings")]
+    public static extern void ReverseStringsAnsi(
+        [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_LPSTR), In, Out] string[] strings
+    );
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false, EntryPoint = "ReverseStrings")]
+    public static extern void ReverseStringsUnicode(
+        [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_LPWSTR), In, Out] string[] strings
+    );
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false, EntryPoint = "ReverseStrings")]
+    public static extern void ReverseStringsBSTR(
+        [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_BSTR), In, Out] string[] strings
+    );
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false, EntryPoint = "VerifyInterfaceArray")]
+    private static extern void VerifyInterfaceArrayIUnknown(
+        [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_UNKNOWN)] object[] objects,
+        short expectedVarType
+    );
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false, EntryPoint = "VerifyInterfaceArray")]
+    private static extern void VerifyInterfaceArrayIDispatch(
+        [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_DISPATCH)] object[] objects,
+        short expectedVarType
+    );
+
+    public static void VerifyIUnknownArray(object[] objects)
+    {
+        VerifyInterfaceArrayIUnknown(objects, (short)VarEnum.VT_UNKNOWN);
+    }
+
+    public static void VerifyIDispatchArray(object[] objects)
+    {
+        VerifyInterfaceArrayIDispatch(objects, (short)VarEnum.VT_DISPATCH);
+    }
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false)]
+    public static extern void MeanVariantIntArray(
+        [MarshalAs(UnmanagedType.SafeArray, SafeArraySubType = VarEnum.VT_VARIANT)]
+        object[] objects,
+        out int result
+    );
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false)]
+    public static extern void DistanceBetweenDates(
+        [MarshalAs(UnmanagedType.SafeArray)] DateTime[] dates,
+        out double result
+    );
+
+    [DllImport(nameof(SafeArrayNative), PreserveSig = false)]
+    public static extern void XorBoolArrayInStruct(StructWithSafeArray str, out bool result);
+}
diff --git a/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.csproj b/tests/src/Interop/ArrayMarshalling/SafeArray/SafeArrayTest.csproj
new file mode 100644 (file)
index 0000000..8234f78
--- /dev/null
@@ -0,0 +1,43 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" />
+  <PropertyGroup>
+    <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
+    <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
+    <AssemblyName>SafeArrayTest</AssemblyName>
+    <SchemaVersion>2.0</SchemaVersion>
+    <ProjectGuid>{F1E66554-8C8E-4141-85CF-D0CD6A0CD0B0}</ProjectGuid>
+    <OutputType>Exe</OutputType>
+    <ProjectTypeGuids>{786C830F-07A1-408B-BD7F-6EE04809D6DB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}</ProjectTypeGuids>
+    <SolutionDir Condition="$(SolutionDir) == '' Or $(SolutionDir) == '*Undefined*'">..\..\</SolutionDir>
+
+    <DefineConstants>$(DefineConstants);STATIC</DefineConstants>
+    <!-- Test unsupported outside of windows -->
+    <TestUnsupportedOutsideWindows>true</TestUnsupportedOutsideWindows>
+    <DisableProjectBuild Condition="'$(TargetsUnix)' == 'true'">true</DisableProjectBuild>
+  </PropertyGroup>
+  <!-- Default configurations to help VS understand the configurations -->
+  <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x64'">
+  </PropertyGroup>
+  <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x64'">
+  </PropertyGroup>
+  <ItemGroup>
+    <CodeAnalysisDependentAssemblyPaths Condition=" '$(VS100COMNTOOLS)' != '' " Include="$(VS100COMNTOOLS)..\IDE\PrivateAssemblies">
+      <Visible>False</Visible>
+    </CodeAnalysisDependentAssemblyPaths>
+  </ItemGroup>
+  <ItemGroup>
+    <Compile Include="*.cs" />
+  </ItemGroup>
+  <ItemGroup>
+    <Service Include="{82A7F48D-3B50-4B1E-B82E-3ADA8210C358}" />
+  </ItemGroup>
+  <ItemGroup>
+    <ProjectReference Include="..\..\..\Common\CoreCLRTestLibrary\CoreCLRTestLibrary.csproj">
+      <Project>{c8c0dc74-fac4-45b1-81fe-70c4808366e0}</Project>
+      <Name>CoreCLRTestLibrary</Name>
+    </ProjectReference>
+    <ProjectReference Include="CMakeLists.txt" />
+  </ItemGroup>
+  <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
+</Project>
index a1a55e7..54e20a8 100644 (file)
@@ -71,6 +71,7 @@ if(WIN32)
     add_subdirectory(PInvoke/NativeCallManagedComVisible)
     # This test doesn't necessarily need to be Windows-only, but the implementation is very tied to Windows APIs
     add_subdirectory(PInvoke/DateTime)
+    add_subdirectory(ArrayMarshalling/SafeArray)
     add_subdirectory(COM/NativeServer)
     add_subdirectory(COM/NativeClients/Primitives)
     add_subdirectory(IJW/FakeMscoree)
index 73e9b91..63219ad 100644 (file)
@@ -1,4 +1,4 @@
-#include "platformdefines.h"
+#include <platformdefines.h>
 
 #include <stdio.h>
 #include <stdlib.h>
index 26f78c7..0365167 100644 (file)
@@ -2,7 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-#include "platformdefines.h"
+#include <platformdefines.h>
 
 WCHAR strManaged[] = W("Managed\0String\0");
 size_t lenstrManaged = sizeof(strManaged) - sizeof(WCHAR);
index 08f069c..28779a7 100644 (file)
@@ -3,7 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 #include <xplatform.h>
-#include "platformdefines.h"
+#include <platformdefines.h>
 #include "SeqStructDelRevPInvokeNative.h"
 
 const char* NativeStr = "Native";