From c6aeba6a672f30a2d335742a01ad735b2319e4f3 Mon Sep 17 00:00:00 2001 From: Norbert Federa Date: Mon, 30 May 2016 17:54:59 +0200 Subject: [PATCH] winpr/wtsapi: fixed race conditions and tests --- winpr/include/winpr/wtsapi.h | 1 + winpr/libwinpr/wtsapi/CMakeLists.txt | 3 +- .../wtsapi/test/TestWtsApiEnumerateProcesses.c | 26 ++++++-- .../wtsapi/test/TestWtsApiEnumerateSessions.c | 21 ++++-- .../test/TestWtsApiQuerySessionInformation.c | 54 +++++++++------ .../wtsapi/test/TestWtsApiSessionNotification.c | 36 ++++++++-- .../wtsapi/test/TestWtsApiShutdownSystem.c | 12 +++- .../wtsapi/test/TestWtsApiWaitSystemEvent.c | 11 ++- winpr/libwinpr/wtsapi/wtsapi.c | 78 ++++++++++++++++------ winpr/libwinpr/wtsapi/wtsapi.h | 44 ------------ winpr/libwinpr/wtsapi/wtsapi_win32.c | 76 ++++++++++++++++----- winpr/libwinpr/wtsapi/wtsapi_win32.h | 2 +- 12 files changed, 243 insertions(+), 121 deletions(-) delete mode 100644 winpr/libwinpr/wtsapi/wtsapi.h diff --git a/winpr/include/winpr/wtsapi.h b/winpr/include/winpr/wtsapi.h index a623a33..cc50b3e 100644 --- a/winpr/include/winpr/wtsapi.h +++ b/winpr/include/winpr/wtsapi.h @@ -1389,6 +1389,7 @@ extern "C" { WINPR_API BOOL WTSRegisterWtsApiFunctionTable(PWtsApiFunctionTable table); WINPR_API const CHAR* WTSErrorToString(UINT error); +WINPR_API const CHAR* WTSSessionStateToString(WTS_CONNECTSTATE_CLASS state); #ifdef __cplusplus } diff --git a/winpr/libwinpr/wtsapi/CMakeLists.txt b/winpr/libwinpr/wtsapi/CMakeLists.txt index 71c1646..1d8a403 100644 --- a/winpr/libwinpr/wtsapi/CMakeLists.txt +++ b/winpr/libwinpr/wtsapi/CMakeLists.txt @@ -15,8 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -winpr_module_add(wtsapi.c - wtsapi.h) +winpr_module_add(wtsapi.c) if(WIN32) winpr_module_add(wtsapi_win32.c wtsapi_win32.h) diff --git a/winpr/libwinpr/wtsapi/test/TestWtsApiEnumerateProcesses.c b/winpr/libwinpr/wtsapi/test/TestWtsApiEnumerateProcesses.c index ea60ac7..64f7ba0 100644 --- a/winpr/libwinpr/wtsapi/test/TestWtsApiEnumerateProcesses.c +++ b/winpr/libwinpr/wtsapi/test/TestWtsApiEnumerateProcesses.c @@ -2,26 +2,44 @@ #include #include #include +#include int TestWtsApiEnumerateProcesses(int argc, char* argv[]) { DWORD count; BOOL bSuccess; HANDLE hServer; - PWTS_PROCESS_INFO pProcessInfo; + PWTS_PROCESS_INFOA pProcessInfo; + +#ifndef _WIN32 + if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0)) + { + printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__); + return 0; + } +#endif hServer = WTS_CURRENT_SERVER_HANDLE; count = 0; pProcessInfo = NULL; - bSuccess = WTSEnumerateProcesses(hServer, 0, 1, &pProcessInfo, &count); + bSuccess = WTSEnumerateProcessesA(hServer, 0, 1, &pProcessInfo, &count); if (!bSuccess) { - printf("WTSEnumerateProcesses failed: %d\n", (int) GetLastError()); - //return -1; + printf("WTSEnumerateProcesses failed: %u\n", GetLastError()); + return -1; + } + +#if 0 + { + DWORD i; + printf("WTSEnumerateProcesses enumerated %u processs:\n", count); + for (i = 0; i < count; i++) + printf("\t[%u]: %s (%lu)\n", i, pProcessInfo[i].pProcessName, pProcessInfo[i].ProcessId); } +#endif WTSFreeMemory(pProcessInfo); diff --git a/winpr/libwinpr/wtsapi/test/TestWtsApiEnumerateSessions.c b/winpr/libwinpr/wtsapi/test/TestWtsApiEnumerateSessions.c index b52e10b..bec10bd 100644 --- a/winpr/libwinpr/wtsapi/test/TestWtsApiEnumerateSessions.c +++ b/winpr/libwinpr/wtsapi/test/TestWtsApiEnumerateSessions.c @@ -2,6 +2,7 @@ #include #include #include +#include int TestWtsApiEnumerateSessions(int argc, char* argv[]) { @@ -9,14 +10,22 @@ int TestWtsApiEnumerateSessions(int argc, char* argv[]) DWORD count; BOOL bSuccess; HANDLE hServer; - PWTS_SESSION_INFO pSessionInfo; + PWTS_SESSION_INFOA pSessionInfo; + +#ifndef _WIN32 + if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0)) + { + printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__); + return 0; + } +#endif hServer = WTS_CURRENT_SERVER_HANDLE; count = 0; pSessionInfo = NULL; - bSuccess = WTSEnumerateSessions(hServer, 0, 1, &pSessionInfo, &count); + bSuccess = WTSEnumerateSessionsA(hServer, 0, 1, &pSessionInfo, &count); if (!bSuccess) { @@ -28,9 +37,11 @@ int TestWtsApiEnumerateSessions(int argc, char* argv[]) for (index = 0; index < count; index++) { - printf("[%d] SessionId: %d State: %d\n", (int) index, - (int) pSessionInfo[index].SessionId, - (int) pSessionInfo[index].State); + printf("[%u] SessionId: %u WinstationName: '%s' State: %s (%u)\n", index, + pSessionInfo[index].SessionId, + pSessionInfo[index].pWinStationName, + WTSSessionStateToString(pSessionInfo[index].State), + pSessionInfo[index].State); } WTSFreeMemory(pSessionInfo); diff --git a/winpr/libwinpr/wtsapi/test/TestWtsApiQuerySessionInformation.c b/winpr/libwinpr/wtsapi/test/TestWtsApiQuerySessionInformation.c index 7212869..8351806 100644 --- a/winpr/libwinpr/wtsapi/test/TestWtsApiQuerySessionInformation.c +++ b/winpr/libwinpr/wtsapi/test/TestWtsApiQuerySessionInformation.c @@ -2,24 +2,33 @@ #include #include #include +#include int TestWtsApiQuerySessionInformation(int argc, char* argv[]) { - DWORD index; + DWORD index, i; DWORD count; BOOL bSuccess; HANDLE hServer; LPSTR pBuffer; DWORD sessionId; DWORD bytesReturned; - PWTS_SESSION_INFO pSessionInfo; + PWTS_SESSION_INFOA pSessionInfo; + +#ifndef _WIN32 + if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0)) + { + printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__); + return 0; + } +#endif hServer = WTS_CURRENT_SERVER_HANDLE; count = 0; pSessionInfo = NULL; - bSuccess = WTSEnumerateSessions(hServer, 0, 1, &pSessionInfo, &count); + bSuccess = WTSEnumerateSessionsA(hServer, 0, 1, &pSessionInfo, &count); if (!bSuccess) { @@ -47,9 +56,12 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) sessionId = pSessionInfo[index].SessionId; - printf("[%d] SessionId: %d State: %d\n", (int) index, - (int) pSessionInfo[index].SessionId, - (int) pSessionInfo[index].State); + printf("[%u] SessionId: %u State: %s (%u) WinstationName: '%s'\n", + index, + pSessionInfo[index].SessionId, + WTSSessionStateToString(pSessionInfo[index].State), + pSessionInfo[index].State, + pSessionInfo[index].pWinStationName); /* WTSUserName */ @@ -62,7 +74,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) } Username = (char*) pBuffer; - printf("\tWTSUserName: %s\n", Username); + printf("\tWTSUserName: '%s'\n", Username); /* WTSDomainName */ @@ -75,7 +87,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) } Domain = (char*) pBuffer; - printf("\tWTSDomainName: %s\n", Domain); + printf("\tWTSDomainName: '%s'\n", Domain); /* WTSConnectState */ @@ -88,7 +100,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) } ConnectState = *((WTS_CONNECTSTATE_CLASS*) pBuffer); - printf("\tWTSConnectState: %d\n", (int) ConnectState); + printf("\tWTSConnectState: %u (%s)\n", ConnectState, WTSSessionStateToString(ConnectState)); /* WTSClientBuildNumber */ @@ -101,7 +113,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) } ClientBuildNumber = *((ULONG*) pBuffer); - printf("\tWTSClientBuildNumber: %d\n", (int) ClientBuildNumber); + printf("\tWTSClientBuildNumber: %u\n", ClientBuildNumber); /* WTSClientName */ @@ -114,7 +126,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) } ClientName = (char*) pBuffer; - printf("\tWTSClientName: %s\n", ClientName); + printf("\tWTSClientName: '%s'\n", ClientName); /* WTSClientProductId */ @@ -127,7 +139,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) } ClientProductId = *((USHORT*) pBuffer); - printf("\tWTSClientProductId: %d\n", (int) ClientProductId); + printf("\tWTSClientProductId: %u\n", ClientProductId); /* WTSClientHardwareId */ @@ -140,7 +152,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) } ClientHardwareId = *((ULONG*) pBuffer); - printf("\tWTSClientHardwareId: %d\n", (int) ClientHardwareId); + printf("\tWTSClientHardwareId: %u\n", ClientHardwareId); /* WTSClientAddress */ @@ -153,8 +165,12 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) } ClientAddress = (PWTS_CLIENT_ADDRESS) pBuffer; - printf("\tWTSClientAddress: AddressFamily: %d\n", - (int) ClientAddress->AddressFamily); + printf("\tWTSClientAddress: AddressFamily: %u Address: ", + ClientAddress->AddressFamily); + for (i = 0; i < sizeof(ClientAddress->Address); i++) + printf("%02X", ClientAddress->Address[i]); + printf("\n"); + /* WTSClientDisplay */ @@ -167,9 +183,9 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) } ClientDisplay = (PWTS_CLIENT_DISPLAY) pBuffer; - printf("\tWTSClientDisplay: HorizontalResolution: %d VerticalResolution: %d ColorDepth: %d\n", - (int) ClientDisplay->HorizontalResolution, (int) ClientDisplay->VerticalResolution, - (int) ClientDisplay->ColorDepth); + printf("\tWTSClientDisplay: HorizontalResolution: %u VerticalResolution: %u ColorDepth: %u\n", + ClientDisplay->HorizontalResolution, ClientDisplay->VerticalResolution, + ClientDisplay->ColorDepth); /* WTSClientProtocolType */ @@ -182,7 +198,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[]) } ClientProtocolType = *((USHORT*) pBuffer); - printf("\tWTSClientProtocolType: %d\n", (int) ClientProtocolType); + printf("\tWTSClientProtocolType: %u\n", ClientProtocolType); } WTSFreeMemory(pSessionInfo); diff --git a/winpr/libwinpr/wtsapi/test/TestWtsApiSessionNotification.c b/winpr/libwinpr/wtsapi/test/TestWtsApiSessionNotification.c index 03cf8a5..3e9b047 100644 --- a/winpr/libwinpr/wtsapi/test/TestWtsApiSessionNotification.c +++ b/winpr/libwinpr/wtsapi/test/TestWtsApiSessionNotification.c @@ -2,30 +2,54 @@ #include #include #include +#include int TestWtsApiSessionNotification(int argc, char* argv[]) { - HWND hWnd; + HWND hWnd = NULL; BOOL bSuccess; DWORD dwFlags; - hWnd = NULL; +#ifndef _WIN32 + if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0)) + { + printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__); + return 0; + } +#else + /* We create a message-only window and use the predefined class name "STATIC" for simplicity */ + hWnd = CreateWindowA("STATIC", "TestWtsApiSessionNotification", 0, 0, 0, 0, 0, HWND_MESSAGE, NULL, NULL, NULL); + if (!hWnd) + { + printf("%s: error creating message-only window: %u\n", __FUNCTION__, GetLastError()); + return -1; + } +#endif + dwFlags = NOTIFY_FOR_ALL_SESSIONS; bSuccess = WTSRegisterSessionNotification(hWnd, dwFlags); if (!bSuccess) { - printf("WTSRegisterSessionNotification failed: %d\n", (int) GetLastError()); - //return -1; + printf("%s: WTSRegisterSessionNotification failed: %u\n", __FUNCTION__, GetLastError()); + return -1; } bSuccess = WTSUnRegisterSessionNotification(hWnd); +#ifdef _WIN32 + if (hWnd) + { + DestroyWindow(hWnd); + hWnd = NULL; + } +#endif + if (!bSuccess) { - printf("WTSUnRegisterSessionNotification failed: %d\n", (int) GetLastError()); - //return -1; + printf("%s: WTSUnRegisterSessionNotification failed: %u\n", __FUNCTION__, GetLastError()); + return -1; } return 0; diff --git a/winpr/libwinpr/wtsapi/test/TestWtsApiShutdownSystem.c b/winpr/libwinpr/wtsapi/test/TestWtsApiShutdownSystem.c index 9157d5a..80d5335 100644 --- a/winpr/libwinpr/wtsapi/test/TestWtsApiShutdownSystem.c +++ b/winpr/libwinpr/wtsapi/test/TestWtsApiShutdownSystem.c @@ -1,7 +1,7 @@ - #include #include #include +#include int TestWtsApiShutdownSystem(int argc, char* argv[]) { @@ -9,6 +9,14 @@ int TestWtsApiShutdownSystem(int argc, char* argv[]) HANDLE hServer; DWORD ShutdownFlag; +#ifndef _WIN32 + if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0)) + { + printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__); + return 0; + } +#endif + hServer = WTS_CURRENT_SERVER_HANDLE; ShutdownFlag = WTS_WSD_SHUTDOWN; @@ -17,7 +25,7 @@ int TestWtsApiShutdownSystem(int argc, char* argv[]) if (!bSuccess) { printf("WTSShutdownSystem failed: %d\n", (int) GetLastError()); - //return -1; + return -1; } return 0; diff --git a/winpr/libwinpr/wtsapi/test/TestWtsApiWaitSystemEvent.c b/winpr/libwinpr/wtsapi/test/TestWtsApiWaitSystemEvent.c index 44b9cb3..e42879f 100644 --- a/winpr/libwinpr/wtsapi/test/TestWtsApiWaitSystemEvent.c +++ b/winpr/libwinpr/wtsapi/test/TestWtsApiWaitSystemEvent.c @@ -2,6 +2,7 @@ #include #include #include +#include int TestWtsApiWaitSystemEvent(int argc, char* argv[]) { @@ -10,6 +11,14 @@ int TestWtsApiWaitSystemEvent(int argc, char* argv[]) DWORD eventMask; DWORD eventFlags; +#ifndef _WIN32 + if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0)) + { + printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__); + return 0; + } +#endif + hServer = WTS_CURRENT_SERVER_HANDLE; eventMask = WTS_EVENT_ALL; @@ -20,7 +29,7 @@ int TestWtsApiWaitSystemEvent(int argc, char* argv[]) if (!bSuccess) { printf("WTSWaitSystemEvent failed: %d\n", (int) GetLastError()); - //return -1; + return -1; } return 0; diff --git a/winpr/libwinpr/wtsapi/wtsapi.c b/winpr/libwinpr/wtsapi/wtsapi.c index 59f8e86..024423f 100644 --- a/winpr/libwinpr/wtsapi/wtsapi.c +++ b/winpr/libwinpr/wtsapi/wtsapi.c @@ -32,8 +32,6 @@ #include -#include "wtsapi.h" - #ifdef _WIN32 #include "wtsapi_win32.h" #endif @@ -46,9 +44,6 @@ * http://msdn.microsoft.com/en-us/library/windows/desktop/aa383464/ */ -void InitializeWtsApiStubs(void); - -static BOOL g_Initialized = FALSE; static HMODULE g_WtsApiModule = NULL; static PWtsApiFunctionTable g_WtsApi = NULL; @@ -128,12 +123,12 @@ static WtsApiFunctionTable WtsApi32_WtsApiFunctionTable = #define WTSAPI32_LOAD_PROC(_name, _type) \ WtsApi32_WtsApiFunctionTable.p ## _name = (## _type) GetProcAddress(g_WtsApi32Module, "WTS" #_name); -int WtsApi32_InitializeWtsApi(void) +BOOL WtsApi32_InitializeWtsApi(void) { g_WtsApi32Module = LoadLibraryA("wtsapi32.dll"); if (!g_WtsApi32Module) - return -1; + return FALSE; #ifdef _WIN32 WTSAPI32_LOAD_PROC(StopRemoteControlSession, WTS_STOP_REMOTE_CONTROL_SESSION_FN); @@ -205,11 +200,33 @@ int WtsApi32_InitializeWtsApi(void) g_WtsApi = &WtsApi32_WtsApiFunctionTable; - return 1; + return TRUE; } /* WtsApi Functions */ +static BOOL CALLBACK InitializeWtsApiStubs(PINIT_ONCE once, PVOID param, PVOID *context); +static INIT_ONCE wtsapiInitOnce = INIT_ONCE_STATIC_INIT; + +#define WTSAPI_STUB_CALL_VOID(_name, ...) \ + InitOnceExecuteOnce(&wtsapiInitOnce, InitializeWtsApiStubs, NULL, NULL); \ + if (!g_WtsApi || !g_WtsApi->p ## _name) \ + return; \ + g_WtsApi->p ## _name ( __VA_ARGS__ ) + +#define WTSAPI_STUB_CALL_BOOL(_name, ...) \ + InitOnceExecuteOnce(&wtsapiInitOnce, InitializeWtsApiStubs, NULL, NULL); \ + if (!g_WtsApi || !g_WtsApi->p ## _name) \ + return FALSE; \ + return g_WtsApi->p ## _name ( __VA_ARGS__ ) + +#define WTSAPI_STUB_CALL_HANDLE(_name, ...) \ + InitOnceExecuteOnce(&wtsapiInitOnce, InitializeWtsApiStubs, NULL, NULL); \ + if (!g_WtsApi || !g_WtsApi->p ## _name) \ + return NULL; \ + return g_WtsApi->p ## _name ( __VA_ARGS__ ) + + BOOL WINAPI WTSStartRemoteControlSessionW(LPWSTR pTargetServerName, ULONG TargetLogonId, BYTE HotkeyVk, USHORT HotkeyModifiers) { WTSAPI_STUB_CALL_BOOL(StartRemoteControlSessionW, pTargetServerName, TargetLogonId, HotkeyVk, HotkeyModifiers); @@ -566,8 +583,7 @@ BOOL CDECL WTSLogoffUser(HANDLE hServer) DWORD WINAPI WTSGetActiveConsoleSessionId(void) { - if (!g_Initialized) - InitializeWtsApiStubs(); + InitOnceExecuteOnce(&wtsapiInitOnce, InitializeWtsApiStubs, NULL, NULL); if (!g_WtsApi || !g_WtsApi->pGetActiveConsoleSessionId) return 0xFFFFFFFF; @@ -649,10 +665,37 @@ const CHAR* WTSErrorToString(UINT error) } } +const CHAR* WTSSessionStateToString(WTS_CONNECTSTATE_CLASS state) +{ + switch (state) + { + case WTSActive: + return "WTSActive"; + case WTSConnected: + return "WTSConnected"; + case WTSConnectQuery: + return "WTSConnectQuery"; + case WTSShadow: + return "WTSShadow"; + case WTSDisconnected: + return "WTSDisconnected"; + case WTSIdle: + return "WTSIdle"; + case WTSListen: + return "WTSListen"; + case WTSReset: + return "WTSReset"; + case WTSDown: + return "WTSDown"; + case WTSInit: + return "WTSInit"; + } + return "INVALID_STATE"; +} + BOOL WTSRegisterWtsApiFunctionTable(PWtsApiFunctionTable table) { g_WtsApi = table; - g_Initialized = TRUE; return TRUE; } @@ -675,7 +718,7 @@ static BOOL LoadAndInitialize(char* library) return TRUE; } -void InitializeWtsApiStubs_Env() +static void InitializeWtsApiStubs_Env() { DWORD nSize; char *env = NULL; @@ -699,7 +742,7 @@ void InitializeWtsApiStubs_Env() #define FREERDS_LIBRARY_NAME "libfreerds-fdsapi.so" -void InitializeWtsApiStubs_FreeRDS() +static void InitializeWtsApiStubs_FreeRDS() { wIniFile* ini; const char* prefix; @@ -741,12 +784,9 @@ void InitializeWtsApiStubs_FreeRDS() IniFile_Free(ini); } -void InitializeWtsApiStubs(void) -{ - if (g_Initialized) - return; - g_Initialized = TRUE; +static BOOL CALLBACK InitializeWtsApiStubs(PINIT_ONCE once, PVOID param, PVOID *context) +{ InitializeWtsApiStubs_Env(); #ifdef _WIN32 @@ -756,5 +796,5 @@ void InitializeWtsApiStubs(void) if (!g_WtsApi) InitializeWtsApiStubs_FreeRDS(); - return; + return TRUE; } diff --git a/winpr/libwinpr/wtsapi/wtsapi.h b/winpr/libwinpr/wtsapi/wtsapi.h deleted file mode 100644 index d83ce20..0000000 --- a/winpr/libwinpr/wtsapi/wtsapi.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * WinPR: Windows Portable Runtime - * Windows Terminal Services API - * - * Copyright 2013 Marc-Andre Moreau - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef WINPR_WTSAPI_PRIVATE_H -#define WINPR_WTSAPI_PRIVATE_H - -#define WTSAPI_STUB_CALL_VOID(_name, ...) \ - if (!g_Initialized) \ - InitializeWtsApiStubs(); \ - if (!g_WtsApi || !g_WtsApi->p ## _name) \ - return; \ - g_WtsApi->p ## _name ( __VA_ARGS__ ) - -#define WTSAPI_STUB_CALL_BOOL(_name, ...) \ - if (!g_Initialized) \ - InitializeWtsApiStubs(); \ - if (!g_WtsApi || !g_WtsApi->p ## _name) \ - return FALSE; \ - return g_WtsApi->p ## _name ( __VA_ARGS__ ) - -#define WTSAPI_STUB_CALL_HANDLE(_name, ...) \ - if (!g_Initialized) \ - InitializeWtsApiStubs(); \ - if (!g_WtsApi || !g_WtsApi->p ## _name) \ - return NULL; \ - return g_WtsApi->p ## _name ( __VA_ARGS__ ) - -#endif /* WINPR_WTSAPI_PRIVATE_H */ diff --git a/winpr/libwinpr/wtsapi/wtsapi_win32.c b/winpr/libwinpr/wtsapi/wtsapi_win32.c index d6d708c..2a2556b 100644 --- a/winpr/libwinpr/wtsapi/wtsapi_win32.c +++ b/winpr/libwinpr/wtsapi/wtsapi_win32.c @@ -30,7 +30,6 @@ #include "wtsapi_win32.h" -#include "wtsapi.h" #include "../log.h" #define WTSAPI_CHANNEL_MAGIC 0x44484356 @@ -72,6 +71,37 @@ static fnWinStationVirtualOpenEx pfnWinStationVirtualOpenEx = NULL; BOOL WINAPI Win32_WTSVirtualChannelClose(HANDLE hChannel); + +/** + * NOTE !! + * An application using the WinPR wtsapi frees memory via WTSFreeMemory, which + * might be mapped to Win32_WTSFreeMemory. Latter does not know if the passed + * pointer was allocated by a function in wtsapi32.dll or in some internal + * code below. The WTSFreeMemory implementation in all Windows wtsapi32.dll + * versions up to Windows 10 uses LocalFree since all its allocating functions + * use LocalAlloc() internally. + * For that reason we also have to use LocalAlloc() for any memory returned by + * our WinPR wtsapi functions. + * + * To be safe we only use the _wts_malloc, _wts_calloc, _wts_free wrappers + * for memory managment the code below. + */ + +static void *_wts_malloc(size_t size) +{ + return (PVOID)LocalAlloc(LMEM_FIXED, size); +} + +static void *_wts_calloc(size_t nmemb, size_t size) +{ + return (PVOID)LocalAlloc(LMEM_FIXED | LMEM_ZEROINIT, nmemb * size); +} + +static void _wts_free(void* ptr) +{ + LocalFree((HLOCAL)ptr); +} + BOOL Win32_WTSVirtualChannelReadAsync(WTSAPI_CHANNEL* pChannel) { BOOL status = TRUE; @@ -130,19 +160,28 @@ HANDLE WINAPI Win32_WTSVirtualChannelOpen_Internal(HANDLE hServer, DWORD Session HANDLE hFile; HANDLE hChannel; WTSAPI_CHANNEL* pChannel; + size_t virtualNameLen; + + virtualNameLen = pVirtualName ? strlen(pVirtualName) : 0; - if (!pVirtualName) + if (!virtualNameLen) { SetLastError(ERROR_INVALID_PARAMETER); return NULL; } + if (!pfnWinStationVirtualOpenEx) + { + SetLastError(ERROR_INVALID_FUNCTION); + return NULL; + } + hFile = pfnWinStationVirtualOpenEx(hServer, SessionId, pVirtualName, flags); if (!hFile) return NULL; - pChannel = (WTSAPI_CHANNEL*) calloc(1, sizeof(WTSAPI_CHANNEL)); + pChannel = (WTSAPI_CHANNEL*) _wts_calloc(1, sizeof(WTSAPI_CHANNEL)); if (!pChannel) { @@ -156,14 +195,15 @@ HANDLE WINAPI Win32_WTSVirtualChannelOpen_Internal(HANDLE hServer, DWORD Session pChannel->hServer = hServer; pChannel->SessionId = SessionId; pChannel->hFile = hFile; - pChannel->VirtualName = _strdup(pVirtualName); + pChannel->VirtualName = _wts_calloc(1, virtualNameLen + 1); if (!pChannel->VirtualName) { CloseHandle(hFile); SetLastError(ERROR_NOT_ENOUGH_MEMORY); - free(pChannel); + _wts_free(pChannel); return NULL; } + memcpy(pChannel->VirtualName, pVirtualName, virtualNameLen); pChannel->flags = flags; pChannel->dynamic = (flags & WTS_CHANNEL_OPTION_DYNAMIC) ? TRUE : FALSE; @@ -171,7 +211,7 @@ HANDLE WINAPI Win32_WTSVirtualChannelOpen_Internal(HANDLE hServer, DWORD Session pChannel->showProtocol = pChannel->dynamic; pChannel->readSize = CHANNEL_PDU_LENGTH; - pChannel->readBuffer = (BYTE*) malloc(pChannel->readSize); + pChannel->readBuffer = (BYTE*) _wts_malloc(pChannel->readSize); pChannel->header = (CHANNEL_PDU_HEADER*) pChannel->readBuffer; pChannel->chunk = &(pChannel->readBuffer[sizeof(CHANNEL_PDU_HEADER)]); @@ -230,18 +270,18 @@ BOOL WINAPI Win32_WTSVirtualChannelClose(HANDLE hChannel) if (pChannel->VirtualName) { - free(pChannel->VirtualName); + _wts_free(pChannel->VirtualName); pChannel->VirtualName = NULL; } if (pChannel->readBuffer) { - free(pChannel->readBuffer); + _wts_free(pChannel->readBuffer); pChannel->readBuffer = NULL; } pChannel->magic = 0; - free(pChannel); + _wts_free(pChannel); return status; } @@ -660,7 +700,7 @@ BOOL WINAPI Win32_WTSVirtualChannelQuery(HANDLE hChannelHandle, WTS_VIRTUAL_CLAS else if (WtsVirtualClass == WTSVirtualFileHandle) { *pBytesReturned = sizeof(HANDLE); - *ppBuffer = calloc(1, *pBytesReturned); + *ppBuffer = _wts_calloc(1, *pBytesReturned); if (*ppBuffer == NULL) { @@ -673,7 +713,7 @@ BOOL WINAPI Win32_WTSVirtualChannelQuery(HANDLE hChannelHandle, WTS_VIRTUAL_CLAS else if (WtsVirtualClass == WTSVirtualEventHandle) { *pBytesReturned = sizeof(HANDLE); - *ppBuffer = calloc(1, *pBytesReturned); + *ppBuffer = _wts_calloc(1, *pBytesReturned); if (*ppBuffer == NULL) { @@ -697,12 +737,12 @@ BOOL WINAPI Win32_WTSVirtualChannelQuery(HANDLE hChannelHandle, WTS_VIRTUAL_CLAS VOID WINAPI Win32_WTSFreeMemory(PVOID pMemory) { - free(pMemory); + _wts_free(pMemory); } BOOL WINAPI Win32_WTSFreeMemoryExW(WTS_TYPE_CLASS WTSTypeClass, PVOID pMemory, ULONG NumberOfEntries) { - return TRUE; + return FALSE; } BOOL WINAPI Win32_WTSFreeMemoryExA(WTS_TYPE_CLASS WTSTypeClass, PVOID pMemory, ULONG NumberOfEntries) @@ -710,18 +750,18 @@ BOOL WINAPI Win32_WTSFreeMemoryExA(WTS_TYPE_CLASS WTSTypeClass, PVOID pMemory, U return WTSFreeMemoryExW(WTSTypeClass, pMemory, NumberOfEntries); } -int Win32_InitializeWinSta(PWtsApiFunctionTable pWtsApi) +BOOL Win32_InitializeWinSta(PWtsApiFunctionTable pWtsApi) { g_WinStaModule = LoadLibraryA("winsta.dll"); if (!g_WinStaModule) - return -1; + return FALSE; pfnWinStationVirtualOpen = (fnWinStationVirtualOpen) GetProcAddress(g_WinStaModule, "WinStationVirtualOpen"); pfnWinStationVirtualOpenEx = (fnWinStationVirtualOpenEx) GetProcAddress(g_WinStaModule, "WinStationVirtualOpenEx"); - if (!pfnWinStationVirtualOpenEx) - return -1; + if (!pfnWinStationVirtualOpen | !pfnWinStationVirtualOpenEx) + return FALSE; pWtsApi->pVirtualChannelOpen = Win32_WTSVirtualChannelOpen; pWtsApi->pVirtualChannelOpenEx = Win32_WTSVirtualChannelOpenEx; @@ -735,5 +775,5 @@ int Win32_InitializeWinSta(PWtsApiFunctionTable pWtsApi) //pWtsApi->pFreeMemoryExW = Win32_WTSFreeMemoryExW; //pWtsApi->pFreeMemoryExA = Win32_WTSFreeMemoryExA; - return 1; + return TRUE; } diff --git a/winpr/libwinpr/wtsapi/wtsapi_win32.h b/winpr/libwinpr/wtsapi/wtsapi_win32.h index c6add2f..7d43165 100644 --- a/winpr/libwinpr/wtsapi/wtsapi_win32.h +++ b/winpr/libwinpr/wtsapi/wtsapi_win32.h @@ -22,6 +22,6 @@ #include -int Win32_InitializeWinSta(PWtsApiFunctionTable pWtsApi); +BOOL Win32_InitializeWinSta(PWtsApiFunctionTable pWtsApi); #endif /* WINPR_WTSAPI_WIN32_PRIVATE_H */ -- 2.7.4