OK I found the way to handle private keys of PEM form using CryptoAPI.
authorHiroshi Inoue <inoue@tpf.co.p>
Thu, 24 Oct 2013 08:49:58 +0000 (17:49 +0900)
committerHiroshi Inoue <inoue@tpf.co.p>
Sat, 26 Oct 2013 11:05:32 +0000 (20:05 +0900)
Certificates of PFX form are no longer needed for SSL client certificate
authntication.

sspisvcs.c

index 232da1160c76a44e1afab0960825de1c6d1e9c2c..d446e85b3338bbe5aa6043f9a3599c0dc81aff79 100644 (file)
@@ -1,5 +1,5 @@
 /*-------
- * Module:         sspi_proc.c
+ * Module:         sspisvcs.c
  *
  * Description:        This module contains functions for low level socket
  *                 operations (connecting/reading/writing to the backend)
@@ -22,6 +22,8 @@
 
 #include "sspisvcs.h"
 #include "socket.h"
+#include "connection.h"
+#include "environ.h"
 
 /*
  * To handle EWOULDBLOCK etc (mainly for libpq non-blocking connection).
@@ -252,15 +254,28 @@ static void GetNewSchannelClientCredentials(PCredHandle, CtxtHandle *);
 
 static BOOL        bMyCert = FALSE;
 static HCERTSTORE  hMyCertStore  = NULL;
-static const DWORD propId = CERT_KEY_PROV_INFO_PROP_ID;
+static HCRYPTPROV  hProv = (HCRYPTPROV) 0;
+static PCCERT_CONTEXT  pClientCertContext = NULL;
 
 static void FreeCertStores()
 {
+   shortterm_common_lock();
+   if (pClientCertContext)
+   {
+           CertFreeCertificateContext(pClientCertContext);
+       pClientCertContext = NULL;
+   }
+   if (hProv)
+   {
+       CryptReleaseContext(hProv, 0);
+       hProv = (HCRYPTPROV) 0;
+   }
    if (hMyCertStore)
    {
        CertCloseStore(hMyCertStore, CERT_CLOSE_STORE_FORCE_FLAG);
        hMyCertStore = NULL;
    }
+   shortterm_common_unlock();
 }
 
 void LeaveSSPIService()
@@ -269,16 +284,20 @@ void LeaveSSPIService()
    bMyCert = FALSE;
 }
 
-static void CertStoreInit()
+/*
+ * This driver allows certificates of PFX form when a pair of
+ * postgresql.crt and postgresql.key doesn't work well. 
+ */
+static void CertStoreInit_pfx()
 {
+   BOOL    success = FALSE;
    LPCTSTR pgsslpfx = NULL;
    LPCTSTR appdata = NULL;
    TCHAR   sslpfx[256];
-   HANDLE  fd;
+   HANDLE  fd = INVALID_HANDLE_VALUE;
    DWORD   flen, rlen;
    CRYPT_DATA_BLOB crypt_data;
 
-
    if (bMyCert)            return;
    if (hMyCertStore != NULL)   return;
 
@@ -297,29 +316,178 @@ static void CertStoreInit()
    if (INVALID_HANDLE_VALUE == fd)
    {
        mylog("!!! pfxfile=%s not found\n", pgsslpfx);
-       FreeCertStores();
-       return;
+       goto cleanup;
    }
    flen = GetFileSize(fd, NULL);
    if (flen <= 0)
    {
-       FreeCertStores();
-       CloseHandle(fd);
-       return;
+       goto cleanup;
    }
    crypt_data.cbData = flen;
    crypt_data.pbData = (LPBYTE) CryptMemAlloc(flen);
    ReadFile(fd, crypt_data.pbData, flen, &rlen, NULL);
-   CloseHandle(fd);  
+   CloseHandle(fd);
+   fd = INVALID_HANDLE_VALUE;  
    hMyCertStore = PFXImportCertStore(&crypt_data, L"", 0);
    CryptMemFree(crypt_data.pbData);
+
+   success = TRUE;
+cleanup:
+   if (fd != INVALID_HANDLE_VALUE)
+       CloseHandle(fd);
+   if (!success)
+       FreeCertStores();
+}
+
+static void CertStoreInit()
+{
+   BOOL    success = FALSE;
+   LPCTSTR pgsslkey = NULL, pgsslcert = NULL;
+   LPCTSTR appdata = NULL;
+   TCHAR   sslkey[256], sslcert[256];
+   HANDLE  fd = INVALID_HANDLE_VALUE;
+   DWORD   flen, rlen;
+   char    *pemdata = NULL;
+   DWORD   dwBufferLen, cbKeyBlob;
+   LPBYTE  pbBuffer = NULL, pbKeyBlob = NULL;
+   HCRYPTKEY   hKey = (HCRYPTKEY) 0;
+
+   if (hMyCertStore != NULL)   return;
+
+   bMyCert = TRUE; 
+
+   pgsslkey = getenv("PGSSLKEY");
+   if (!pgsslkey)
+   {
+       if (!appdata)
+           appdata = getenv("APPDATA");
+       if (!appdata)           goto cleanup;
+       snprintf(sslkey, sizeof(sslkey), "%s\\postgresql\\postgresql.key", appdata);
+       pgsslkey = sslkey;
+   }
+
+   fd = CreateFile(pgsslkey, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
+   if (INVALID_HANDLE_VALUE == fd)
+   {
+       mylog("!!! keyfile=%s not found\n", pgsslkey);
+       goto cleanup;
+   }
+   flen = GetFileSize(fd, NULL);
+   if (flen <= 0)
+   {
+       goto cleanup;
+   }
+   if (pemdata = malloc(flen), NULL == pemdata)
+       goto cleanup;
+   ReadFile(fd, pemdata, flen, &rlen, NULL);
+   CloseHandle(fd);
+   fd = INVALID_HANDLE_VALUE;
+  
+   if (!CryptStringToBinaryA(pemdata, 0, CRYPT_STRING_BASE64HEADER,
+       NULL, &dwBufferLen, NULL, NULL))
+   {
+       mylog("Failed to convert BASE64 private key. Error 0x%.8X\n",
+GetLastError());
+       goto cleanup;
+   }
+   if (pbBuffer = malloc(dwBufferLen), NULL == pbBuffer)
+       goto cleanup;
+   if (!CryptStringToBinaryA(pemdata, 0, CRYPT_STRING_BASE64HEADER,
+       pbBuffer, &dwBufferLen, NULL, NULL))
+   {
+       mylog("Failed to convert BASE64 private key. Error 0x%.8X\n", GetLastError());
+       goto cleanup;
+   }
+   free(pemdata);
+   pemdata = NULL;
+
+   if (!CryptDecodeObjectEx(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
+       PKCS_RSA_PRIVATE_KEY, pbBuffer, dwBufferLen, 0, NULL, NULL, &cbKeyBlob))
+   {
+       mylog("Failed to parse private key. Error 0x%.8X\n", GetLastError());
+       goto cleanup;
+   }
+   if (pbKeyBlob = malloc(cbKeyBlob), NULL == pbKeyBlob)
+       goto cleanup;
+   if (!CryptDecodeObjectEx(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
+       PKCS_RSA_PRIVATE_KEY, pbBuffer, dwBufferLen, 0, NULL, pbKeyBlob, &cbKeyBlob))
+   {
+       mylog("Failed to parse private key. Error 0x%.8X\n", GetLastError());
+       goto cleanup;
+   }
+   free(pbBuffer);
+   pbBuffer = NULL;
+
+   // Create a temporary and volatile CSP context in order to import
+   // the key
+   if (!CryptAcquireContext(&hProv, NULL, MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT))
+   {
+       mylog("CryptAcquireContext failed with error 0x%.8X\n", GetLastError());
+       goto cleanup;
+   }
+   // import private key
+   if (!CryptImportKey(hProv, pbKeyBlob, cbKeyBlob, (HCRYPTKEY) 0, 0, &hKey))
+   {
+       mylog("CryptImportKey failed with error 0x%.8X\n", GetLastError());
+       goto cleanup;
+   }
+   free(pbKeyBlob);
+   pbKeyBlob = NULL;
+
+   pgsslcert = getenv("PGSSLCERT");
+   if (!pgsslcert)
+   {
+       appdata = getenv("APPDATA");
+       if (!appdata)           goto cleanup;
+       snprintf(sslcert, sizeof(sslcert), "%s\\postgresql\\postgresql.crt", appdata);
+       pgsslcert = sslcert;
+   }
+
+   hMyCertStore = CertOpenStore(
+           CERT_STORE_PROV_FILENAME_A
+           , 0
+           , (HCRYPTPROV) NULL
+           , CERT_STORE_OPEN_EXISTING_FLAG | CERT_STORE_READONLY_FLAG
+           , pgsslcert
+           );
+mylog("!!! hMyCertStore=%p sslcert=%s sslkey=%s\n", hMyCertStore, pgsslcert, pgsslkey);
+   if (hMyCertStore != NULL)
+   {
+       PCCERT_CONTEXT  pContext = NULL;
+
+       pContext = CertEnumCertificatesInStore(hMyCertStore, pContext);
+       while (pContext != NULL)
+       {
+           CertSetCertificateContextProperty(pContext, CERT_KEY_PROV_HANDLE_PROP_ID, 0, (const void *) hProv);
+           pContext = CertEnumCertificatesInStore(hMyCertStore, pContext);
+       }
+   }
+
+   success = TRUE;
+cleanup:
+   if (pemdata)
+       free(pemdata);
+   if (pbBuffer)
+       free(pbBuffer);
+   if (pbKeyBlob)
+       free(pbKeyBlob);
+   if (fd != INVALID_HANDLE_VALUE)
+       CloseHandle(fd);
+   if (hKey)
+       CryptDestroyKey(hKey);
+   if (!success)
+   {
+       HCERTSTORE  hSv = hMyCertStore;
+
+       FreeCertStores();
+       if (hSv == NULL)
+           CertStoreInit_pfx();
+   }
    if (!hMyCertStore)
    {
        mylog("!!! hMyCertStore=%p %d\n", hMyCertStore, GetLastError());
-       FreeCertStores();
-       return;
    }
-
+   return;
 }
 
 static int DoSchannelNegotiation(SocketClass *self, SspiData *sspidata, const void *opt)
@@ -398,7 +566,9 @@ CreateSchannelCredentials(
     * certificate. Otherwise, just create a NULL credential.
     */
 
-   if (pszUserName)
+   if (pClientCertContext)
+       pCertContext = pClientCertContext;
+   else if (pszUserName)
    {
        /* Find client certificate. Note that this sample just searchs for a 
         * certificate that contains the user name somewhere in the subject name.
@@ -482,17 +652,15 @@ CreateSchannelCredentials(
 
 cleanup:
 
-    /*
-     * Free the certificate context. Schannel has already made its own copy.
-     */
-
-    if(pCertContext)
-    {
-        CertFreeCertificateContext(pCertContext);
-    }
-
+   /*
+    * Free the certificate context. Schannel has already made its own copy.
+    */
+   if (pCertContext && pCertContext != pClientCertContext)
+   {
+       CertFreeCertificateContext(pCertContext);
+   }
 
-    return Status;
+   return Status;
 }
 
 static
@@ -848,7 +1016,9 @@ SchannelClientHandshakeLoop(
             */
             
            mylog("Server returned SEC_I_INCOMPLETE_CREDENTIALS\n");
+           shortterm_common_lock();
            GetNewSchannelClientCredentials(phCreds, phContext);
+           shortterm_common_unlock();
 
            /* Go around again. */
            fDoRead = FALSE;
@@ -892,8 +1062,6 @@ GetNewSchannelClientCredentials(
 {
    SCHANNEL_CRED   SchannelCred;
    CredHandle  hCreds;
-   SecPkgContext_IssuerListInfoEx  IssuerListInfo;
-   CERT_CHAIN_FIND_BY_ISSUER_PARA  FindByIssuerPara;
    PCCERT_CONTEXT  pCertContext;
    TimeStamp   tsExpiry;
    SECURITY_STATUS Status;
@@ -901,33 +1069,10 @@ GetNewSchannelClientCredentials(
    CertStoreInit();
    if (hMyCertStore == NULL)
        return;
-   /*
-    * Read list of trusted issuers from schannel.
-    */
-
-   Status = QueryContextAttributes(phContext,
-                   SECPKG_ATTR_ISSUER_LIST_EX,
-                   (PVOID)&IssuerListInfo);
-   if (Status != SEC_E_OK)
-   {
-       mylog("Error 0x%p querying issuer list info\n", Status);
-       return;
-   }
-
    /*
     * Enumerate the client certificates.
     */
-
-   ZeroMemory(&FindByIssuerPara, sizeof(FindByIssuerPara));
-
-   FindByIssuerPara.cbSize = sizeof(FindByIssuerPara);
-   FindByIssuerPara.pszUsageIdentifier = szOID_PKIX_KP_CLIENT_AUTH;
-   FindByIssuerPara.dwKeySpec  = 0;
-   FindByIssuerPara.cIssuer    = IssuerListInfo.cIssuers;
-   FindByIssuerPara.rgIssuer   = IssuerListInfo.aIssuers;
-
    pCertContext = NULL;
-
    while (TRUE)
    {
        pCertContext = CertEnumCertificatesInStore(hMyCertStore,
@@ -960,7 +1105,7 @@ GetNewSchannelClientCredentials(
             mylog("**** Error 0x%p returned by AcquireCredentialsHandle\n", Status);
            continue;
        }
-       mylog("new schannel credential created\n");
+       mylog("new schannel client credential created\n");
 
        /* Destroy the old credentials. */
        FreeCredentialsHandle(phCreds);
@@ -984,6 +1129,8 @@ GetNewSchannelClientCredentials(
         * the time is rather expensive.
         */
 
+       if (pCertContext)
+           pClientCertContext = pCertContext;
        break;
    }
 }