winpr/wtsapi: fixed race conditions and tests
authorNorbert Federa <norbert.federa@thincast.com>
Mon, 30 May 2016 15:54:59 +0000 (17:54 +0200)
committerNorbert Federa <norbert.federa@thincast.com>
Mon, 30 May 2016 15:54:59 +0000 (17:54 +0200)
12 files changed:
winpr/include/winpr/wtsapi.h
winpr/libwinpr/wtsapi/CMakeLists.txt
winpr/libwinpr/wtsapi/test/TestWtsApiEnumerateProcesses.c
winpr/libwinpr/wtsapi/test/TestWtsApiEnumerateSessions.c
winpr/libwinpr/wtsapi/test/TestWtsApiQuerySessionInformation.c
winpr/libwinpr/wtsapi/test/TestWtsApiSessionNotification.c
winpr/libwinpr/wtsapi/test/TestWtsApiShutdownSystem.c
winpr/libwinpr/wtsapi/test/TestWtsApiWaitSystemEvent.c
winpr/libwinpr/wtsapi/wtsapi.c
winpr/libwinpr/wtsapi/wtsapi.h [deleted file]
winpr/libwinpr/wtsapi/wtsapi_win32.c
winpr/libwinpr/wtsapi/wtsapi_win32.h

index a623a33..cc50b3e 100644 (file)
@@ -1389,6 +1389,7 @@ extern "C" {
 
 WINPR_API BOOL WTSRegisterWtsApiFunctionTable(PWtsApiFunctionTable table);
 WINPR_API const CHAR* WTSErrorToString(UINT error);
+WINPR_API const CHAR* WTSSessionStateToString(WTS_CONNECTSTATE_CLASS state);
 
 #ifdef __cplusplus
 }
index 71c1646..1d8a403 100644 (file)
@@ -15,8 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-winpr_module_add(wtsapi.c
-       wtsapi.h)
+winpr_module_add(wtsapi.c)
 
 if(WIN32)
        winpr_module_add(wtsapi_win32.c wtsapi_win32.h)
index ea60ac7..64f7ba0 100644 (file)
@@ -2,26 +2,44 @@
 #include <winpr/crt.h>
 #include <winpr/error.h>
 #include <winpr/wtsapi.h>
+#include <winpr/environment.h>
 
 int TestWtsApiEnumerateProcesses(int argc, char* argv[])
 {
        DWORD count;
        BOOL bSuccess;
        HANDLE hServer;
-       PWTS_PROCESS_INFO pProcessInfo;
+       PWTS_PROCESS_INFOA pProcessInfo;
+
+#ifndef _WIN32
+       if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0))
+       {
+               printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__);
+               return 0;
+       }
+#endif
 
        hServer = WTS_CURRENT_SERVER_HANDLE;
 
        count = 0;
        pProcessInfo = NULL;
 
-       bSuccess = WTSEnumerateProcesses(hServer, 0, 1, &pProcessInfo, &count);
+       bSuccess = WTSEnumerateProcessesA(hServer, 0, 1, &pProcessInfo, &count);
 
        if (!bSuccess)
        {
-               printf("WTSEnumerateProcesses failed: %d\n", (int) GetLastError());
-               //return -1;
+               printf("WTSEnumerateProcesses failed: %u\n", GetLastError());
+               return -1;
+       }
+
+#if 0
+       {
+               DWORD i;
+               printf("WTSEnumerateProcesses enumerated %u processs:\n", count);
+               for (i = 0; i < count; i++)
+                       printf("\t[%u]: %s (%lu)\n", i, pProcessInfo[i].pProcessName, pProcessInfo[i].ProcessId);
        }
+#endif
 
        WTSFreeMemory(pProcessInfo);
 
index b52e10b..bec10bd 100644 (file)
@@ -2,6 +2,7 @@
 #include <winpr/crt.h>
 #include <winpr/error.h>
 #include <winpr/wtsapi.h>
+#include <winpr/environment.h>
 
 int TestWtsApiEnumerateSessions(int argc, char* argv[])
 {
@@ -9,14 +10,22 @@ int TestWtsApiEnumerateSessions(int argc, char* argv[])
        DWORD count;
        BOOL bSuccess;
        HANDLE hServer;
-       PWTS_SESSION_INFO pSessionInfo;
+       PWTS_SESSION_INFOA pSessionInfo;
+
+#ifndef _WIN32
+       if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0))
+       {
+               printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__);
+               return 0;
+       }
+#endif
 
        hServer = WTS_CURRENT_SERVER_HANDLE;
 
        count = 0;
        pSessionInfo = NULL;
 
-       bSuccess = WTSEnumerateSessions(hServer, 0, 1, &pSessionInfo, &count);
+       bSuccess = WTSEnumerateSessionsA(hServer, 0, 1, &pSessionInfo, &count);
 
        if (!bSuccess)
        {
@@ -28,9 +37,11 @@ int TestWtsApiEnumerateSessions(int argc, char* argv[])
 
        for (index = 0; index < count; index++)
        {
-               printf("[%d] SessionId: %d State: %d\n", (int) index,
-                               (int) pSessionInfo[index].SessionId,
-                               (int) pSessionInfo[index].State);
+               printf("[%u] SessionId: %u WinstationName: '%s' State: %s (%u)\n", index,
+                       pSessionInfo[index].SessionId,
+                       pSessionInfo[index].pWinStationName,
+                       WTSSessionStateToString(pSessionInfo[index].State),
+                       pSessionInfo[index].State);
        }
 
        WTSFreeMemory(pSessionInfo);
index 7212869..8351806 100644 (file)
@@ -2,24 +2,33 @@
 #include <winpr/crt.h>
 #include <winpr/error.h>
 #include <winpr/wtsapi.h>
+#include <winpr/environment.h>
 
 int TestWtsApiQuerySessionInformation(int argc, char* argv[])
 {
-       DWORD index;
+       DWORD index, i;
        DWORD count;
        BOOL bSuccess;
        HANDLE hServer;
        LPSTR pBuffer;
        DWORD sessionId;
        DWORD bytesReturned;
-       PWTS_SESSION_INFO pSessionInfo;
+       PWTS_SESSION_INFOA pSessionInfo;
+
+#ifndef _WIN32
+       if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0))
+       {
+               printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__);
+               return 0;
+       }
+#endif
 
        hServer = WTS_CURRENT_SERVER_HANDLE;
 
        count = 0;
        pSessionInfo = NULL;
 
-       bSuccess = WTSEnumerateSessions(hServer, 0, 1, &pSessionInfo, &count);
+       bSuccess = WTSEnumerateSessionsA(hServer, 0, 1, &pSessionInfo, &count);
 
        if (!bSuccess)
        {
@@ -47,9 +56,12 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
 
                sessionId = pSessionInfo[index].SessionId;
 
-               printf("[%d] SessionId: %d State: %d\n", (int) index,
-                               (int) pSessionInfo[index].SessionId,
-                               (int) pSessionInfo[index].State);
+               printf("[%u] SessionId: %u State: %s (%u) WinstationName: '%s'\n",
+                               index,
+                               pSessionInfo[index].SessionId,
+                               WTSSessionStateToString(pSessionInfo[index].State),
+                               pSessionInfo[index].State,
+                               pSessionInfo[index].pWinStationName);
 
                /* WTSUserName */
 
@@ -62,7 +74,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
                }
 
                Username = (char*) pBuffer;
-               printf("\tWTSUserName: %s\n", Username);
+               printf("\tWTSUserName: '%s'\n", Username);
 
                /* WTSDomainName */
 
@@ -75,7 +87,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
                }
 
                Domain = (char*) pBuffer;
-               printf("\tWTSDomainName: %s\n", Domain);
+               printf("\tWTSDomainName: '%s'\n", Domain);
 
                /* WTSConnectState */
 
@@ -88,7 +100,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
                }
 
                ConnectState = *((WTS_CONNECTSTATE_CLASS*) pBuffer);
-               printf("\tWTSConnectState: %d\n", (int) ConnectState);
+               printf("\tWTSConnectState: %u (%s)\n", ConnectState, WTSSessionStateToString(ConnectState));
 
                /* WTSClientBuildNumber */
 
@@ -101,7 +113,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
                }
 
                ClientBuildNumber = *((ULONG*) pBuffer);
-               printf("\tWTSClientBuildNumber: %d\n", (int) ClientBuildNumber);
+               printf("\tWTSClientBuildNumber: %u\n", ClientBuildNumber);
 
                /* WTSClientName */
 
@@ -114,7 +126,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
                }
 
                ClientName = (char*) pBuffer;
-               printf("\tWTSClientName: %s\n", ClientName);
+               printf("\tWTSClientName: '%s'\n", ClientName);
 
                /* WTSClientProductId */
 
@@ -127,7 +139,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
                }
 
                ClientProductId = *((USHORT*) pBuffer);
-               printf("\tWTSClientProductId: %d\n", (int) ClientProductId);
+               printf("\tWTSClientProductId: %u\n", ClientProductId);
 
                /* WTSClientHardwareId */
 
@@ -140,7 +152,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
                }
 
                ClientHardwareId = *((ULONG*) pBuffer);
-               printf("\tWTSClientHardwareId: %d\n", (int) ClientHardwareId);
+               printf("\tWTSClientHardwareId: %u\n", ClientHardwareId);
 
                /* WTSClientAddress */
 
@@ -153,8 +165,12 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
                }
 
                ClientAddress = (PWTS_CLIENT_ADDRESS) pBuffer;
-               printf("\tWTSClientAddress: AddressFamily: %d\n",
-                               (int) ClientAddress->AddressFamily);
+               printf("\tWTSClientAddress: AddressFamily: %u Address: ",
+                               ClientAddress->AddressFamily);
+               for (i = 0; i < sizeof(ClientAddress->Address); i++)
+                       printf("%02X", ClientAddress->Address[i]);
+               printf("\n");
+
 
                /* WTSClientDisplay */
 
@@ -167,9 +183,9 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
                }
 
                ClientDisplay = (PWTS_CLIENT_DISPLAY) pBuffer;
-               printf("\tWTSClientDisplay: HorizontalResolution: %d VerticalResolution: %d ColorDepth: %d\n",
-                               (int) ClientDisplay->HorizontalResolution, (int) ClientDisplay->VerticalResolution,
-                               (int) ClientDisplay->ColorDepth);
+               printf("\tWTSClientDisplay: HorizontalResolution: %u VerticalResolution: %u ColorDepth: %u\n",
+                               ClientDisplay->HorizontalResolution, ClientDisplay->VerticalResolution,
+                               ClientDisplay->ColorDepth);
 
                /* WTSClientProtocolType */
 
@@ -182,7 +198,7 @@ int TestWtsApiQuerySessionInformation(int argc, char* argv[])
                }
 
                ClientProtocolType = *((USHORT*) pBuffer);
-               printf("\tWTSClientProtocolType: %d\n", (int) ClientProtocolType);
+               printf("\tWTSClientProtocolType: %u\n", ClientProtocolType);
        }
 
        WTSFreeMemory(pSessionInfo);
index 03cf8a5..3e9b047 100644 (file)
@@ -2,30 +2,54 @@
 #include <winpr/crt.h>
 #include <winpr/error.h>
 #include <winpr/wtsapi.h>
+#include <winpr/environment.h>
 
 int TestWtsApiSessionNotification(int argc, char* argv[])
 {
-       HWND hWnd;
+       HWND hWnd = NULL;
        BOOL bSuccess;
        DWORD dwFlags;
 
-       hWnd = NULL;
+#ifndef _WIN32
+       if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0))
+       {
+               printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__);
+               return 0;
+       }
+#else
+       /* We create a message-only window and use the predefined class name "STATIC" for simplicity */
+       hWnd = CreateWindowA("STATIC", "TestWtsApiSessionNotification", 0, 0, 0, 0, 0, HWND_MESSAGE, NULL, NULL, NULL);
+       if (!hWnd)
+       {
+               printf("%s: error creating message-only window: %u\n", __FUNCTION__, GetLastError());
+               return -1;
+       }
+#endif
+
        dwFlags = NOTIFY_FOR_ALL_SESSIONS;
 
        bSuccess = WTSRegisterSessionNotification(hWnd, dwFlags);
 
        if (!bSuccess)
        {
-               printf("WTSRegisterSessionNotification failed: %d\n", (int) GetLastError());
-               //return -1;
+               printf("%s: WTSRegisterSessionNotification failed: %u\n", __FUNCTION__, GetLastError());
+               return -1;
        }
 
        bSuccess = WTSUnRegisterSessionNotification(hWnd);
 
+#ifdef _WIN32
+       if (hWnd)
+       {
+               DestroyWindow(hWnd);
+               hWnd = NULL;
+       }
+#endif
+
        if (!bSuccess)
        {
-               printf("WTSUnRegisterSessionNotification failed: %d\n", (int) GetLastError());
-               //return -1;
+               printf("%s: WTSUnRegisterSessionNotification failed: %u\n", __FUNCTION__, GetLastError());
+               return -1;
        }
 
        return 0;
index 9157d5a..80d5335 100644 (file)
@@ -1,7 +1,7 @@
-
 #include <winpr/crt.h>
 #include <winpr/error.h>
 #include <winpr/wtsapi.h>
+#include <winpr/environment.h>
 
 int TestWtsApiShutdownSystem(int argc, char* argv[])
 {
@@ -9,6 +9,14 @@ int TestWtsApiShutdownSystem(int argc, char* argv[])
        HANDLE hServer;
        DWORD ShutdownFlag;
 
+#ifndef _WIN32
+       if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0))
+       {
+               printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__);
+               return 0;
+       }
+#endif
+
        hServer = WTS_CURRENT_SERVER_HANDLE;
        ShutdownFlag = WTS_WSD_SHUTDOWN;
 
@@ -17,7 +25,7 @@ int TestWtsApiShutdownSystem(int argc, char* argv[])
        if (!bSuccess)
        {
                printf("WTSShutdownSystem failed: %d\n", (int) GetLastError());
-               //return -1;
+               return -1;
        }
 
        return 0;
index 44b9cb3..e42879f 100644 (file)
@@ -2,6 +2,7 @@
 #include <winpr/crt.h>
 #include <winpr/error.h>
 #include <winpr/wtsapi.h>
+#include <winpr/environment.h>
 
 int TestWtsApiWaitSystemEvent(int argc, char* argv[])
 {
@@ -10,6 +11,14 @@ int TestWtsApiWaitSystemEvent(int argc, char* argv[])
        DWORD eventMask;
        DWORD eventFlags;
 
+#ifndef _WIN32
+       if (!GetEnvironmentVariableA("WTSAPI_LIBRARY", NULL, 0))
+       {
+               printf("%s: No RDS environment detected, skipping test\n", __FUNCTION__);
+               return 0;
+       }
+#endif
+
        hServer = WTS_CURRENT_SERVER_HANDLE;
 
        eventMask = WTS_EVENT_ALL;
@@ -20,7 +29,7 @@ int TestWtsApiWaitSystemEvent(int argc, char* argv[])
        if (!bSuccess)
        {
                printf("WTSWaitSystemEvent failed: %d\n", (int) GetLastError());
-               //return -1;
+               return -1;
        }
 
        return 0;
index 59f8e86..024423f 100644 (file)
@@ -32,8 +32,6 @@
 
 #include <winpr/wtsapi.h>
 
-#include "wtsapi.h"
-
 #ifdef _WIN32
 #include "wtsapi_win32.h"
 #endif
@@ -46,9 +44,6 @@
  * http://msdn.microsoft.com/en-us/library/windows/desktop/aa383464/
  */
 
-void InitializeWtsApiStubs(void);
-
-static BOOL g_Initialized = FALSE;
 static HMODULE g_WtsApiModule = NULL;
 
 static PWtsApiFunctionTable g_WtsApi = NULL;
@@ -128,12 +123,12 @@ static WtsApiFunctionTable WtsApi32_WtsApiFunctionTable =
 #define WTSAPI32_LOAD_PROC(_name, _type) \
        WtsApi32_WtsApiFunctionTable.p ## _name = (## _type) GetProcAddress(g_WtsApi32Module, "WTS" #_name);
 
-int WtsApi32_InitializeWtsApi(void)
+BOOL WtsApi32_InitializeWtsApi(void)
 {
        g_WtsApi32Module = LoadLibraryA("wtsapi32.dll");
 
        if (!g_WtsApi32Module)
-               return -1;
+               return FALSE;
 
 #ifdef _WIN32
        WTSAPI32_LOAD_PROC(StopRemoteControlSession, WTS_STOP_REMOTE_CONTROL_SESSION_FN);
@@ -205,11 +200,33 @@ int WtsApi32_InitializeWtsApi(void)
 
        g_WtsApi = &WtsApi32_WtsApiFunctionTable;
 
-       return 1;
+       return TRUE;
 }
 
 /* WtsApi Functions */
 
+static BOOL CALLBACK InitializeWtsApiStubs(PINIT_ONCE once, PVOID param, PVOID *context);
+static INIT_ONCE wtsapiInitOnce = INIT_ONCE_STATIC_INIT;
+
+#define WTSAPI_STUB_CALL_VOID(_name, ...) \
+       InitOnceExecuteOnce(&wtsapiInitOnce, InitializeWtsApiStubs, NULL, NULL); \
+       if (!g_WtsApi || !g_WtsApi->p ## _name) \
+               return; \
+       g_WtsApi->p ## _name ( __VA_ARGS__ )
+
+#define WTSAPI_STUB_CALL_BOOL(_name, ...) \
+       InitOnceExecuteOnce(&wtsapiInitOnce, InitializeWtsApiStubs, NULL, NULL); \
+       if (!g_WtsApi || !g_WtsApi->p ## _name) \
+               return FALSE; \
+       return g_WtsApi->p ## _name ( __VA_ARGS__ )
+
+#define WTSAPI_STUB_CALL_HANDLE(_name, ...) \
+       InitOnceExecuteOnce(&wtsapiInitOnce, InitializeWtsApiStubs, NULL, NULL); \
+       if (!g_WtsApi || !g_WtsApi->p ## _name) \
+               return NULL; \
+       return g_WtsApi->p ## _name ( __VA_ARGS__ )
+
+
 BOOL WINAPI WTSStartRemoteControlSessionW(LPWSTR pTargetServerName, ULONG TargetLogonId, BYTE HotkeyVk, USHORT HotkeyModifiers)
 {
        WTSAPI_STUB_CALL_BOOL(StartRemoteControlSessionW, pTargetServerName, TargetLogonId, HotkeyVk, HotkeyModifiers);
@@ -566,8 +583,7 @@ BOOL CDECL WTSLogoffUser(HANDLE hServer)
 
 DWORD WINAPI WTSGetActiveConsoleSessionId(void)
 {
-       if (!g_Initialized)
-               InitializeWtsApiStubs();
+       InitOnceExecuteOnce(&wtsapiInitOnce, InitializeWtsApiStubs, NULL, NULL);
 
        if (!g_WtsApi || !g_WtsApi->pGetActiveConsoleSessionId)
                return 0xFFFFFFFF;
@@ -649,10 +665,37 @@ const CHAR* WTSErrorToString(UINT error)
        }
 }
 
+const CHAR* WTSSessionStateToString(WTS_CONNECTSTATE_CLASS state)
+{
+       switch (state)
+       {
+       case WTSActive:
+               return "WTSActive";
+       case WTSConnected:
+               return "WTSConnected";
+       case WTSConnectQuery:
+               return "WTSConnectQuery";
+       case WTSShadow:
+               return "WTSShadow";
+       case WTSDisconnected:
+               return "WTSDisconnected";
+       case WTSIdle:
+               return "WTSIdle";
+       case WTSListen:
+               return "WTSListen";
+       case WTSReset:
+               return "WTSReset";
+       case WTSDown:
+               return "WTSDown";
+       case WTSInit:
+               return "WTSInit";
+       }
+       return "INVALID_STATE";
+}
+
 BOOL WTSRegisterWtsApiFunctionTable(PWtsApiFunctionTable table)
 {
        g_WtsApi = table;
-       g_Initialized = TRUE;
        return TRUE;
 }
 
@@ -675,7 +718,7 @@ static BOOL LoadAndInitialize(char* library)
        return TRUE;
 }
 
-void InitializeWtsApiStubs_Env()
+static void InitializeWtsApiStubs_Env()
 {
        DWORD nSize;
        char *env = NULL;
@@ -699,7 +742,7 @@ void InitializeWtsApiStubs_Env()
 
 #define FREERDS_LIBRARY_NAME "libfreerds-fdsapi.so"
 
-void InitializeWtsApiStubs_FreeRDS()
+static void InitializeWtsApiStubs_FreeRDS()
 {
        wIniFile* ini;
        const char* prefix;
@@ -741,12 +784,9 @@ void InitializeWtsApiStubs_FreeRDS()
        IniFile_Free(ini);
 }
 
-void InitializeWtsApiStubs(void)
-{
-       if (g_Initialized)
-               return;
 
-       g_Initialized = TRUE;
+static BOOL CALLBACK InitializeWtsApiStubs(PINIT_ONCE once, PVOID param, PVOID *context)
+{
        InitializeWtsApiStubs_Env();
 
 #ifdef _WIN32
@@ -756,5 +796,5 @@ void InitializeWtsApiStubs(void)
        if (!g_WtsApi)
                InitializeWtsApiStubs_FreeRDS();
 
-       return;
+       return TRUE;
 }
diff --git a/winpr/libwinpr/wtsapi/wtsapi.h b/winpr/libwinpr/wtsapi/wtsapi.h
deleted file mode 100644 (file)
index d83ce20..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-/**
- * WinPR: Windows Portable Runtime
- * Windows Terminal Services API
- *
- * Copyright 2013 Marc-Andre Moreau <marcandre.moreau@gmail.com>
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef WINPR_WTSAPI_PRIVATE_H
-#define WINPR_WTSAPI_PRIVATE_H
-
-#define WTSAPI_STUB_CALL_VOID(_name, ...) \
-       if (!g_Initialized) \
-               InitializeWtsApiStubs(); \
-       if (!g_WtsApi || !g_WtsApi->p ## _name) \
-               return; \
-       g_WtsApi->p ## _name ( __VA_ARGS__ )
-
-#define WTSAPI_STUB_CALL_BOOL(_name, ...) \
-       if (!g_Initialized) \
-               InitializeWtsApiStubs(); \
-       if (!g_WtsApi || !g_WtsApi->p ## _name) \
-               return FALSE; \
-       return g_WtsApi->p ## _name ( __VA_ARGS__ )
-
-#define WTSAPI_STUB_CALL_HANDLE(_name, ...) \
-       if (!g_Initialized) \
-               InitializeWtsApiStubs(); \
-       if (!g_WtsApi || !g_WtsApi->p ## _name) \
-               return NULL; \
-       return g_WtsApi->p ## _name ( __VA_ARGS__ )
-
-#endif /* WINPR_WTSAPI_PRIVATE_H */
index d6d708c..2a2556b 100644 (file)
@@ -30,7 +30,6 @@
 
 #include "wtsapi_win32.h"
 
-#include "wtsapi.h"
 #include "../log.h"
 
 #define WTSAPI_CHANNEL_MAGIC   0x44484356
@@ -72,6 +71,37 @@ static fnWinStationVirtualOpenEx pfnWinStationVirtualOpenEx = NULL;
 
 BOOL WINAPI Win32_WTSVirtualChannelClose(HANDLE hChannel);
 
+
+/**
+  * NOTE !!
+  * An application using the WinPR wtsapi frees memory via WTSFreeMemory, which
+  * might be mapped to Win32_WTSFreeMemory. Latter does not know if the passed
+  * pointer was allocated by a function in wtsapi32.dll or in some internal
+  * code below. The WTSFreeMemory implementation in all Windows wtsapi32.dll
+  * versions up to Windows 10 uses LocalFree since all its allocating functions
+  * use LocalAlloc() internally.
+  * For that reason we also have to use LocalAlloc() for any memory returned by
+  * our WinPR wtsapi functions.
+  *
+  * To be safe we only use the _wts_malloc, _wts_calloc, _wts_free wrappers
+  * for memory managment the code below.
+  */
+
+static void *_wts_malloc(size_t size)
+{
+       return (PVOID)LocalAlloc(LMEM_FIXED, size);
+}
+
+static void *_wts_calloc(size_t nmemb, size_t size)
+{
+       return (PVOID)LocalAlloc(LMEM_FIXED | LMEM_ZEROINIT, nmemb * size);
+}
+
+static void _wts_free(void* ptr)
+{
+       LocalFree((HLOCAL)ptr);
+}
+
 BOOL Win32_WTSVirtualChannelReadAsync(WTSAPI_CHANNEL* pChannel)
 {
        BOOL status = TRUE;
@@ -130,19 +160,28 @@ HANDLE WINAPI Win32_WTSVirtualChannelOpen_Internal(HANDLE hServer, DWORD Session
        HANDLE hFile;
        HANDLE hChannel;
        WTSAPI_CHANNEL* pChannel;
+       size_t virtualNameLen;
+
+       virtualNameLen = pVirtualName ? strlen(pVirtualName) : 0;
 
-       if (!pVirtualName)
+       if (!virtualNameLen)
        {
                SetLastError(ERROR_INVALID_PARAMETER);
                return NULL;
        }
 
+       if (!pfnWinStationVirtualOpenEx)
+       {
+               SetLastError(ERROR_INVALID_FUNCTION);
+               return NULL;
+       }
+
        hFile = pfnWinStationVirtualOpenEx(hServer, SessionId, pVirtualName, flags);
 
        if (!hFile)
                return NULL;
 
-       pChannel = (WTSAPI_CHANNEL*) calloc(1, sizeof(WTSAPI_CHANNEL));
+       pChannel = (WTSAPI_CHANNEL*) _wts_calloc(1, sizeof(WTSAPI_CHANNEL));
 
        if (!pChannel)
        {
@@ -156,14 +195,15 @@ HANDLE WINAPI Win32_WTSVirtualChannelOpen_Internal(HANDLE hServer, DWORD Session
        pChannel->hServer = hServer;
        pChannel->SessionId = SessionId;
        pChannel->hFile = hFile;
-       pChannel->VirtualName = _strdup(pVirtualName);
+       pChannel->VirtualName = _wts_calloc(1, virtualNameLen + 1);
        if (!pChannel->VirtualName)
        {
                CloseHandle(hFile);
                SetLastError(ERROR_NOT_ENOUGH_MEMORY);
-               free(pChannel);
+               _wts_free(pChannel);
                return NULL;
        }
+       memcpy(pChannel->VirtualName, pVirtualName, virtualNameLen);
 
        pChannel->flags = flags;
        pChannel->dynamic = (flags & WTS_CHANNEL_OPTION_DYNAMIC) ? TRUE : FALSE;
@@ -171,7 +211,7 @@ HANDLE WINAPI Win32_WTSVirtualChannelOpen_Internal(HANDLE hServer, DWORD Session
        pChannel->showProtocol = pChannel->dynamic;
 
        pChannel->readSize = CHANNEL_PDU_LENGTH;
-       pChannel->readBuffer = (BYTE*) malloc(pChannel->readSize);
+       pChannel->readBuffer = (BYTE*) _wts_malloc(pChannel->readSize);
 
        pChannel->header = (CHANNEL_PDU_HEADER*) pChannel->readBuffer;
        pChannel->chunk = &(pChannel->readBuffer[sizeof(CHANNEL_PDU_HEADER)]);
@@ -230,18 +270,18 @@ BOOL WINAPI Win32_WTSVirtualChannelClose(HANDLE hChannel)
 
        if (pChannel->VirtualName)
        {
-               free(pChannel->VirtualName);
+               _wts_free(pChannel->VirtualName);
                pChannel->VirtualName = NULL;
        }
 
        if (pChannel->readBuffer)
        {
-               free(pChannel->readBuffer);
+               _wts_free(pChannel->readBuffer);
                pChannel->readBuffer = NULL;
        }
 
        pChannel->magic = 0;
-       free(pChannel);
+       _wts_free(pChannel);
 
        return status;
 }
@@ -660,7 +700,7 @@ BOOL WINAPI Win32_WTSVirtualChannelQuery(HANDLE hChannelHandle, WTS_VIRTUAL_CLAS
        else if (WtsVirtualClass == WTSVirtualFileHandle)
        {
                *pBytesReturned = sizeof(HANDLE);
-               *ppBuffer = calloc(1, *pBytesReturned);
+               *ppBuffer = _wts_calloc(1, *pBytesReturned);
 
                if (*ppBuffer == NULL)
                {
@@ -673,7 +713,7 @@ BOOL WINAPI Win32_WTSVirtualChannelQuery(HANDLE hChannelHandle, WTS_VIRTUAL_CLAS
        else if (WtsVirtualClass == WTSVirtualEventHandle)
        {
                *pBytesReturned = sizeof(HANDLE);
-               *ppBuffer = calloc(1, *pBytesReturned);
+               *ppBuffer = _wts_calloc(1, *pBytesReturned);
 
                if (*ppBuffer == NULL)
                {
@@ -697,12 +737,12 @@ BOOL WINAPI Win32_WTSVirtualChannelQuery(HANDLE hChannelHandle, WTS_VIRTUAL_CLAS
 
 VOID WINAPI Win32_WTSFreeMemory(PVOID pMemory)
 {
-       free(pMemory);
+       _wts_free(pMemory);
 }
 
 BOOL WINAPI Win32_WTSFreeMemoryExW(WTS_TYPE_CLASS WTSTypeClass, PVOID pMemory, ULONG NumberOfEntries)
 {
-       return TRUE;
+       return FALSE;
 }
 
 BOOL WINAPI Win32_WTSFreeMemoryExA(WTS_TYPE_CLASS WTSTypeClass, PVOID pMemory, ULONG NumberOfEntries)
@@ -710,18 +750,18 @@ BOOL WINAPI Win32_WTSFreeMemoryExA(WTS_TYPE_CLASS WTSTypeClass, PVOID pMemory, U
        return WTSFreeMemoryExW(WTSTypeClass, pMemory, NumberOfEntries);
 }
 
-int Win32_InitializeWinSta(PWtsApiFunctionTable pWtsApi)
+BOOL Win32_InitializeWinSta(PWtsApiFunctionTable pWtsApi)
 {
        g_WinStaModule = LoadLibraryA("winsta.dll");
 
        if (!g_WinStaModule)
-               return -1;
+               return FALSE;
 
        pfnWinStationVirtualOpen = (fnWinStationVirtualOpen) GetProcAddress(g_WinStaModule, "WinStationVirtualOpen");
        pfnWinStationVirtualOpenEx = (fnWinStationVirtualOpenEx) GetProcAddress(g_WinStaModule, "WinStationVirtualOpenEx");
 
-       if (!pfnWinStationVirtualOpenEx)
-               return -1;
+       if (!pfnWinStationVirtualOpen | !pfnWinStationVirtualOpenEx)
+               return FALSE;
 
        pWtsApi->pVirtualChannelOpen = Win32_WTSVirtualChannelOpen;
        pWtsApi->pVirtualChannelOpenEx = Win32_WTSVirtualChannelOpenEx;
@@ -735,5 +775,5 @@ int Win32_InitializeWinSta(PWtsApiFunctionTable pWtsApi)
        //pWtsApi->pFreeMemoryExW = Win32_WTSFreeMemoryExW;
        //pWtsApi->pFreeMemoryExA = Win32_WTSFreeMemoryExA;
 
-       return 1;
+       return TRUE;
 }
index c6add2f..7d43165 100644 (file)
@@ -22,6 +22,6 @@
 
 #include <winpr/wtsapi.h>
 
-int Win32_InitializeWinSta(PWtsApiFunctionTable pWtsApi);
+BOOL Win32_InitializeWinSta(PWtsApiFunctionTable pWtsApi);
 
 #endif /* WINPR_WTSAPI_WIN32_PRIVATE_H */