Add support for verify-ca/full sslmode using Windows Schannel Security
authorHiroshi Inoue <inoue@tpf.co.p>
Sat, 26 Oct 2013 11:45:39 +0000 (20:45 +0900)
committerHiroshi Inoue <inoue@tpf.co.jp>
Wed, 30 Oct 2013 12:24:32 +0000 (21:24 +0900)
Service Provider.
Root CAs must be installed into Windows Root certificate store beforehand.

connection.c
dlg_wingui.c
sspisvcs.c
sspisvcs.h

index 7e750a103ae384565170ba9895aea6889436c8f4..62393029e2b45f61c2d35b851189c5f7b3989543 100644 (file)
@@ -1567,6 +1567,7 @@ original_CC_connect(ConnectionClass *self, char password_req, char *salt_para)
 #ifdef USE_SSPI
    int ssl_try_count, ssl_try_no;
    char    ssl_call[2];
+   int bReconnect = 0;
 #endif /* USE_SSPI */
 
    mylog("%s: entering...\n", func);
@@ -1701,8 +1702,13 @@ inolog("protocol=%s version=%d,%d\n", ci->protocol, self->pg_version_major, self
            switch (rnego)
            {
                case 'S':
-                   if (!StartupSspiService(sock, SchannelService, NULL))
+                   if (!StartupSspiService(sock, SchannelService, ci, &bReconnect))
                    {
+                       if (bReconnect != 0)
+                       {
+                           anotherVersionRetry = TRUE;
+                           goto another_version_retry;
+                       }
                        CC_set_error(self, CONN_INVALID_AUTHENTICATION, "Service negotation failed", func);
                        goto error_proc;
                    }
@@ -1941,7 +1947,7 @@ inolog("Ekita retry=%d\n", retry);
                            if (!ci->gssauth_use_gssapi)
                            {
                                self->auth_svcs = KerberosService;
-                               authRet = StartupSspiService(sock, self->auth_svcs, MakePrincHint(ci, TRUE));
+                               authRet = StartupSspiService(sock, self->auth_svcs, MakePrincHint(ci, TRUE), &bReconnect);
                                if (!authRet)
                                {
                                    CC_set_error(self, CONN_INVALID_AUTHENTICATION, "Service negotation failed", func);
@@ -2002,7 +2008,7 @@ inolog("Ekita retry=%d\n", retry);
                            mylog("in AUTH_REQ_SSPI\n");
 #if    defined(USE_SSPI)
                            self->auth_svcs = ci->gssauth_use_gssapi ? KerberosService : NegotiateService;
-                           if (!StartupSspiService(sock, self->auth_svcs, MakePrincHint(ci, TRUE)))
+                           if (!StartupSspiService(sock, self->auth_svcs, MakePrincHint(ci, TRUE), &bReconnect))
                            {
                                CC_set_error(self, CONN_INVALID_AUTHENTICATION, "Service negotation failed", func);
                                goto error_proc;
index f38d856b9d046bc5a9b4799c0245ae82b2217240..090359a4fd6ba9120fbba54308c5fe9e18fc5d54 100644 (file)
@@ -83,7 +83,7 @@ mylog("SendMessage CTL_COLOR\n");
        SendMessage(GetDlgItem(hdlg, IDC_NOTICE_USER), WM_CTLCOLOR, 0, 0);
 #ifdef USE_SSPI
        ShowWindow(GetDlgItem(hdlg, IDC_NOTICE_USER), SW_HIDE);
-       dsplevel = 1;
+       dsplevel = 2;
 #endif /* USE_SSPI */
    }
 
index d446e85b3338bbe5aa6043f9a3599c0dc81aff79..010235b815f8e96a64825228e16f593e3bf25e11 100644 (file)
@@ -152,9 +152,9 @@ typedef struct {
    KerberosEtcSpec kdata;
 } SspiData;
 
-static int DoSchannelNegotiation(SocketClass *, SspiData *, const void *opt);
-static int DoKerberosNegotiation(SocketClass *, SspiData *, const void *opt);
-static int DoNegotiateNegotiation(SocketClass *, SspiData *, const void *opt);
+static int DoSchannelNegotiation(SocketClass *, SspiData *, const void *opt, int *bReconnect);
+static int DoKerberosNegotiation(SocketClass *, SspiData *, const void *opt, int *bReconnect);
+static int DoNegotiateNegotiation(SocketClass *, SspiData *, const void *opt, int *bReconnect);
 static int DoKerberosEtcProcessAuthentication(SocketClass *, const void *opt);
 
 static SspiData *SspiDataAlloc(SocketClass *self)
@@ -166,21 +166,23 @@ static SspiData *SspiDataAlloc(SocketClass *self)
    return sspidata;
 }
 
-int StartupSspiService(SocketClass *self, SSPI_Service svc, const void *opt)
+int StartupSspiService(SocketClass *self, SSPI_Service svc, const void *opt, int *bReconnect)
 {
    CSTR func = "DoServicelNegotiation";
    SspiData    *sspidata;
 
+   if (bReconnect != NULL)
+       *bReconnect = 0;
    if (NULL == (sspidata = SspiDataAlloc(self)))
        return -1;
    switch (svc)
    {
        case SchannelService:
-           return DoSchannelNegotiation(self, sspidata, opt);
+           return DoSchannelNegotiation(self, sspidata, opt, bReconnect);
        case KerberosService:
-           return DoKerberosNegotiation(self, sspidata, opt);
+           return DoKerberosNegotiation(self, sspidata, opt, bReconnect);
        case NegotiateService:
-           return DoNegotiateNegotiation(self, sspidata, opt);
+           return DoNegotiateNegotiation(self, sspidata, opt, bReconnect);
    }
 
    free(sspidata);
@@ -252,6 +254,7 @@ static SECURITY_STATUS PerformSchannelClientHandshake(SOCKET, PCredHandle, LPSTR
 static SECURITY_STATUS SchannelClientHandshakeLoop(SOCKET, PCredHandle, CtxtHandle *, BOOL, SecBuffer *);
 static void GetNewSchannelClientCredentials(PCredHandle, CtxtHandle *);
 
+static BOOL        bRootCALoaded = FALSE;
 static BOOL        bMyCert = FALSE;
 static HCERTSTORE  hMyCertStore  = NULL;
 static HCRYPTPROV  hProv = (HCRYPTPROV) 0;
@@ -282,6 +285,7 @@ void LeaveSSPIService()
 {
    FreeCertStores();
    bMyCert = FALSE;
+   bRootCALoaded = FALSE;
 }
 
 /*
@@ -431,6 +435,8 @@ GetLastError());
        mylog("CryptImportKey failed with error 0x%.8X\n", GetLastError());
        goto cleanup;
    }
+   CryptDestroyKey(hKey);
+   hKey = (HCRYPTKEY) 0;
    free(pbKeyBlob);
    pbKeyBlob = NULL;
 
@@ -473,8 +479,6 @@ cleanup:
        free(pbKeyBlob);
    if (fd != INVALID_HANDLE_VALUE)
        CloseHandle(fd);
-   if (hKey)
-       CryptDestroyKey(hKey);
    if (!success)
    {
        HCERTSTORE  hSv = hMyCertStore;
@@ -490,7 +494,89 @@ cleanup:
    return;
 }
 
-static int DoSchannelNegotiation(SocketClass *self, SspiData *sspidata, const void *opt)
+static int InstallRootCA()
+{
+   HCERTSTORE  hTempCertStore = NULL, hRootCertStore = NULL;
+   LPCTSTR pgsslroot = NULL;
+   LPCTSTR appdata = NULL;
+   TCHAR   sslroot[256];
+   PCCERT_CONTEXT  pContext = NULL;
+   int installed_count = 0, reject_count = 0;
+
+   shortterm_common_lock();
+   if (bRootCALoaded)
+   {   
+       shortterm_common_unlock();
+       goto cleanup;
+   }
+   shortterm_common_unlock();
+
+   pgsslroot = getenv("PGSSLROOTCERT");
+   if (!pgsslroot)
+   {
+       appdata = getenv("APPDATA");
+       if (!appdata)           goto cleanup;
+       snprintf(sslroot, sizeof(sslroot), "%s\\postgresql\\root.crt", appdata);
+       pgsslroot = sslroot;
+   }
+
+   hTempCertStore = CertOpenStore(
+           CERT_STORE_PROV_FILENAME_A
+           , 0
+           , (HCRYPTPROV) NULL
+           , CERT_STORE_OPEN_EXISTING_FLAG | CERT_STORE_READONLY_FLAG
+           , pgsslroot
+           );
+   if (hTempCertStore == NULL)
+       goto cleanup;
+   hRootCertStore = CertOpenSystemStore(0, TEXT("ROOT"));
+   mylog("hRootCertStore=%p sslroot=%s\n", hRootCertStore, pgsslroot);
+   if (hRootCertStore == NULL)
+       goto cleanup;
+
+   pContext = CertEnumCertificatesInStore(hTempCertStore, pContext);
+   while (pContext != NULL)
+   {
+       if (CertAddCertificateContextToStore(hRootCertStore, pContext, CERT_STORE_ADD_NEWER, NULL))
+           installed_count++;
+       else
+       {
+           int lasterror = GetLastError();
+           switch (lasterror)
+           {
+               case  CRYPT_E_EXISTS:
+                   mylog("Certificate already exists\n");
+                   break;
+               case  1223: // ERROR_CANCELED
+                   reject_count++;
+                   mylog("Certificate canceled\n");
+                   break;
+               default:
+                   reject_count++;
+                   mylog("Failed to install root certificate error=%08x\n", lasterror);
+           }
+       }
+       pContext = CertEnumCertificatesInStore(hRootCertStore, pContext);
+   }
+   shortterm_common_lock();
+   if (!bRootCALoaded)
+   {
+       if (installed_count > 0 ||
+               reject_count == 0)
+           bRootCALoaded = TRUE;
+   }
+   shortterm_common_unlock();
+
+cleanup:
+   if (hTempCertStore)
+       CertCloseStore(hTempCertStore, CERT_CLOSE_STORE_FORCE_FLAG);
+   if (hRootCertStore)
+       CertCloseStore(hRootCertStore, CERT_CLOSE_STORE_FORCE_FLAG);
+
+   return installed_count;
+}
+
+static int DoSchannelNegotiation(SocketClass *self, SspiData *sspidata, const void *opt, int *bReconnect)
 {
    CSTR func = "DoSchannelNegotiation";
    SECURITY_STATUS r = SEC_E_OK;
@@ -498,17 +584,30 @@ static int DoSchannelNegotiation(SocketClass *self, SspiData *sspidata, const vo
    SecBuffer   ExtraData;
    BOOL        ret = 0, cCreds = FALSE, cCtxt = FALSE;
    SchannelSpec    *ssd = &(sspidata->sdata);
+   char        *server = NULL;
 
-   if (SEC_E_OK != (r = CreateSchannelCredentials(NULL, NULL, &ssd->hCred)))
+   if (SEC_E_OK != (r = CreateSchannelCredentials(opt, NULL, &ssd->hCred)))
    {
        cmd = "CreateSchannelCredentials";
        mylog("%s:%s failed\n", func, cmd);
        goto cleanup;
    }
    cCreds = TRUE;
-   if (SEC_E_OK != (r = PerformSchannelClientHandshake(self->socket, &ssd->hCred, NULL, &ssd->hCtxt, &ExtraData)))
+   if (opt != NULL)
+       server = ((ConnInfo *) opt)->server;
+   if (SEC_E_OK != (r = PerformSchannelClientHandshake(self->socket, &ssd->hCred, server, &ssd->hCtxt, &ExtraData)))
    {
        cmd = "PerformSchannelClientHandshake";
+       switch (r)
+       {
+           case SEC_E_UNTRUSTED_ROOT:
+               mylog("Installing RootCA\n");
+               if (InstallRootCA() > 0)
+                   *bReconnect = 1;
+               break;
+           default:
+               break;
+       }
        mylog("%s:%s failed\n", func, cmd);
        goto cleanup;
    }
@@ -558,6 +657,7 @@ CreateSchannelCredentials(
    ALG_ID      rgbSupportedAlgs[16];
    DWORD       dwProtocol = SP_PROT_SSL3 | SP_PROT_SSL2;
    DWORD       aiKeyExch = 0;
+   char        *sslmode = NULL;
 
    PCCERT_CONTEXT  pCertContext = NULL;
 
@@ -628,7 +728,16 @@ CreateSchannelCredentials(
     * leave off this flag, in which case the InitializeSecurityContext
     * function will validate the server certificate automatically.
     */
-   SchannelCred.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION;
+   if (opt != NULL)
+       sslmode = ((ConnInfo *) opt)->sslmode;
+   if (sslmode == NULL || sslmode[0] != 'v')
+       SchannelCred.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION;
+   else
+   {
+       if (strcmp(sslmode, "verify-full"))
+           SchannelCred.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK;
+       // InstallRootCA();
+   }
 
    /*
     * Create an SSPI credential.
@@ -1146,7 +1255,7 @@ GetNewSchannelClientCredentials(
 static SECURITY_STATUS CreateKerberosEtcCredentials(LPCTSTR, SEC_CHAR *, LPCTSTR, PCredHandle);
 static SECURITY_STATUS PerformKerberosEtcClientHandshake(SocketClass *, KerberosEtcSpec *ssd, size_t);
 
-static int DoKerberosNegotiation(SocketClass *self, SspiData *sspidata, const void *opt)
+static int DoKerberosNegotiation(SocketClass *self, SspiData *sspidata, const void *opt, int *bReconnect)
 {
    CSTR func = "DoKerberosNegotiation";
    SECURITY_STATUS r = SEC_E_OK;
@@ -1170,7 +1279,7 @@ mylog("!!! CreateKerberosCredentials passed\n");
    return DoKerberosEtcProcessAuthentication(self, NULL);
 }
 
-static int DoNegotiateNegotiation(SocketClass *self, SspiData *sspidata, const void *opt)
+static int DoNegotiateNegotiation(SocketClass *self, SspiData *sspidata, const void *opt, int *bReconnect)
 {
    CSTR func = "DoNegotiateNegotiation";
    SECURITY_STATUS r = SEC_E_OK;
index 99b37b7d886408f9202e662aa712e19bb777200a..07542f0722895496fd29d31c00263261baace665 100755 (executable)
@@ -20,7 +20,7 @@ typedef enum {
 
 void   LeaveSSPIService();
 void   ReleaseSvcSpecData(SocketClass *self, UInt4);
-int    StartupSspiService(SocketClass *self, SSPI_Service svc, const void *opt);
+int    StartupSspiService(SocketClass *self, SSPI_Service svc, const void *opt, int *bReconnect);
 int    ContinueSspiService(SocketClass *self, SSPI_Service svc, const void *opt);
 int    SSPI_recv(SocketClass *self, void *buf, int len);
 int    SSPI_send(SocketClass *self, const void *buf, int len);