Main Page | Class Hierarchy | Class List | File List | Class Members | File Members

AsyncSecureSocketServer.h

Go to the documentation of this file.
00001 /*  AsyncSecureSocketServer.h - SSL communication through the Schannel API
00002     Copyright (C) 2001-2004 Mark Weaver
00003     Written by Mark Weaver <mark@npsl.co.uk>
00004 
00005     Part of the Open-Win32 library.
00006     This library is free software; you can redistribute it and/or
00007     modify it under the terms of the GNU Library General Public
00008     License as published by the Free Software Foundation; either
00009     version 2 of the License, or (at your option) any later version.
00010 
00011     This library is distributed in the hope that it will be useful,
00012     but WITHOUT ANY WARRANTY; without even the implied warranty of
00013     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00014     Library General Public License for more details.
00015 
00016     You should have received a copy of the GNU Library General Public
00017     License along with this library; if not, write to the
00018     Free Software Foundation, Inc., 59 Temple Place - Suite 330,
00019     Boston, MA  02111-1307, USA.
00020 */
00021 
00026 #ifndef OW32_AsyncSecureSocketServer_h
00027 #define OW32_AsyncSecureSocketServer_h
00028 
00029 #include <OW32/AsyncSecureSocket.h>
00030 #include <OW32/SSLInit.h>
00031 
00032 // Open Win32 namespace
00033 namespace OW32
00034 {
00035 
00037 template <class T>
00038 class CAsyncSecureSocketServer : public CAsyncSecureSocket<T>
00039 {
00040 private:
00041     // Prevent copy, assign
00042     CAsyncSecureSocketServer& operator= (const CAsyncSecureSocketServer& );
00043     CAsyncSecureSocketServer(const CAsyncSecureSocketServer& );
00044 
00045 public:
00049     CAsyncSecureSocketServer(CAsyncSocketCallback* pCallback) :
00050         CAsyncSecureSocket<T>(pCallback)
00051     {
00052     }
00053 
00058     CAsyncSecureSocketServer(CAsyncSocketCallback* pCallback, SOCKET s) :
00059         CAsyncSecureSocket<T>(pCallback, s)
00060     {
00061     }
00062 
00067     virtual int shutdown(int how);
00068 
00074     SECURITY_STATUS setServerCertificate(PCCERT_CONTEXT pCertContext,
00075         DWORD dwEnabledProtocols=0);
00076 
00087     static SECURITY_STATUS createCredentialsFromCertificate(
00088         CredHandle* phCreds, PCCERT_CONTEXT pCertContext, DWORD dwEnabledProtocols=0)
00089     {
00090         return CSecureSocket::createCredentialsFromCertificate(phCreds, pCertContext, 
00091             SECPKG_CRED_INBOUND, dwEnabledProtocols);
00092     }
00093 
00094 protected:
00097     virtual void negotiateLoop();
00098 
00099 private:
00100     void NegotiateError();
00101 
00102     void ReadCompletion(BOOL bRet, DWORD cbReceived);
00103     void SendCompletion(BOOL bRet, DWORD cbSent);
00104 
00105     virtual void onSendCompletion(BOOL bRet, DWORD cbBytesSent);
00106     virtual void onReadCompletion(BOOL bRet, DWORD cbBytesReceived);
00107     virtual void onTimeout();
00108 
00109     bool ProcessDecryptedData();
00110     int handshakeSend(const char* data, int length);
00111 
00112     auto_array_ptr<char> m_handshakeIo;
00113     DWORD m_cbHandshakeIo;
00114 
00115     SECURITY_STATUS m_scRet;
00116     SecBuffer m_InBuffers[2];
00117     SecBufferDesc m_InBuffer;
00118     SecBuffer m_OutBuffers[1];
00119     SecBufferDesc m_OutBuffer;
00120     int m_shutdownHow;
00121 
00122     enum { IO_BUFFER_SIZE = 16*1024 };
00123 };
00124 
00125 template <class T>
00126 void CAsyncSecureSocketServer<T>::NegotiateError()
00127 {
00128     switch (m_State)
00129     {
00130     case State_Negotiate:
00131         m_pCallback->onConnectCompletion(FALSE);
00132         break;
00133     case State_Renegotiate: // happening transparently during a read() call
00134     case State_Connected:
00135         m_pCallback->onReadCompletion(FALSE,0);
00136         break;
00137     case State_Shutdown:
00138         m_pCallback->onCloseCompletion(FALSE); // TODO: really close?
00139         break;
00140     }
00141 }
00142 
00143 template <class T>
00144 void CAsyncSecureSocketServer<T>::onTimeout()
00145 {
00146     if (m_State == State_Connected) {
00147         m_pCallback->onTimeout();
00148         return;
00149     }
00150     SetLastError(ERROR_TIMEOUT);
00151     NegotiateError();
00152 }
00153 
00154 template <class T>
00155 void CAsyncSecureSocketServer<T>::negotiateLoop()
00156 {
00157     // start by continuing...
00158     m_scRet = SEC_I_CONTINUE_NEEDED;
00159     
00160     // This is horrible.  Essentially, we don't /know/ how much data schannel
00161     // wants until it has told us so.  The absolute best we could do is have
00162     // a loop that goes around & allocates more buffer until we've got enough.
00163     // However, since we are dealing with SSL/PCT (at present) then we know
00164     // what the limit is.  So 16K will be enough for a single message.
00165     // We can't get the maximum message size until the connection is
00166     // established (sigh).  For the extra data, again, this could return more
00167     // than a maximum message if we over-read.  To cope with this, we /will/
00168     // allocate more extra buffer space than that; and never free it.  Not
00169     // really a possible case, but it will work in spite of it.
00170     // If only there was more information!
00171     m_handshakeIo.reset(new char[IO_BUFFER_SIZE]);
00172     m_cbHandshakeIo = 0;
00173 
00174     // If we have extra data, assume this is from a previous read
00175     // and use it
00176     if (m_ExtraCount > 0) 
00177     {
00178         // Too much extra data?  Oops, we're stuffed now!
00179         if (m_ExtraCount > IO_BUFFER_SIZE) 
00180         {
00181             SetLastError((DWORD)E_UNEXPECTED);
00182             NegotiateError();
00183             return;
00184         }
00185         CopyMemory(m_handshakeIo.get(), m_Extra.get(), m_ExtraCount);
00186         m_cbHandshakeIo = m_ExtraCount;
00187     }
00188     ReadCompletion(TRUE,0);
00189 }
00190 
00191 template <class T>
00192 int CAsyncSecureSocketServer<T>::handshakeSend(const char* data, int length)
00193 {
00194     m_sendData = data;
00195     m_sendLength = length;
00196     m_sendProcessed = 0;
00197     return T::send(data, length);
00198 }
00199 
00200 template <class T>
00201 void CAsyncSecureSocketServer<T>::onSendCompletion(BOOL bRet, DWORD cbSent)
00202 {
00203     if (m_State == State_Initial)
00204     {
00205         m_pCallback->onSendCompletion(bRet, cbSent);
00206         return;
00207     }
00208 
00209     if (m_State == State_Connected)
00210     {
00211         CAsyncSecureSocket<T>::onSendCompletion(bRet, cbSent);
00212         return;
00213     }
00214 
00215     // Pass on some kind of error or a connection closed 
00216     // notification directly on to our client
00217     if (!bRet || cbSent == 0) {
00218         if (cbSent == 0 && bRet) // TODO: hmmm
00219             SetLastError(ERROR_INVALID_FUNCTION);
00220         NegotiateError();
00221         return;
00222     }
00223     SendCompletion(bRet, cbSent);
00224 }
00225 
00226 template <class T>
00227 void CAsyncSecureSocketServer<T>::SendCompletion(BOOL bRet, DWORD cbSent)
00228 {
00229     bRet;
00230     m_sendProcessed += cbSent;
00231 
00232     if (m_sendProcessed < m_sendLength)
00233     {
00234         int ret = T::send(m_sendData + m_sendProcessed, m_sendLength - m_sendProcessed);
00235         if (ret != 0)
00236         {
00237             NegotiateError();
00238         }
00239         return;
00240     }
00241 
00242     if (m_sendData)
00243         g_SecurityFunc.FreeContextBuffer( (PVOID)m_sendData );
00244     m_sendData = NULL;
00245 
00246     switch (m_State)
00247     {
00248     case State_Negotiate:
00249     case State_Renegotiate:
00250         if (ProcessDecryptedData())
00251             ReadCompletion(TRUE, 0);
00252         break;
00253     case State_Shutdown:
00254         // The next thing to do is to actually call shutdown to close the socket
00255         if (CAsyncSecureSocket<T>::shutdown(m_shutdownHow) == SOCKET_ERROR)
00256             NegotiateError();
00257         break;
00258     }
00259 }
00260 
00261 template <class T>
00262 void CAsyncSecureSocketServer<T>::onReadCompletion(BOOL bRet, DWORD cbReceived)
00263 {
00264     if (m_State == State_Initial)
00265     {
00266         m_pCallback->onReadCompletion(bRet, cbReceived);
00267         return;
00268     }
00269 
00270     if (m_State == State_Connected)
00271     {
00272         CAsyncSecureSocket<T>::onReadCompletion(bRet, cbReceived);
00273         return;
00274     }
00275 
00276     // Pass on some kind of error or a connection closed 
00277     // notification  directly to our client
00278     if (!bRet || cbReceived == 0) {
00279         if (cbReceived == 0 && bRet) // TODO: hmmm
00280             SetLastError(ERROR_INVALID_FUNCTION);
00281         NegotiateError();
00282         return;
00283     }
00284     ReadCompletion(bRet, cbReceived);
00285 }
00286 
00287 template <class T>
00288 void CAsyncSecureSocketServer<T>::ReadCompletion(BOOL bRet, DWORD cbReceived)
00289 {
00290     bRet;
00291     m_cbHandshakeIo += cbReceived;
00292 
00293     while ( m_scRet == SEC_I_CONTINUE_NEEDED        ||
00294             m_scRet == SEC_E_INCOMPLETE_MESSAGE     ||
00295             m_scRet == SEC_I_INCOMPLETE_CREDENTIALS) 
00296     {
00297         // Read data if required
00298         if (0 == m_cbHandshakeIo || m_scRet == SEC_E_INCOMPLETE_MESSAGE)
00299         {
00300             // Too much data; give up
00301             if (m_cbHandshakeIo >= IO_BUFFER_SIZE) {
00302                 SetLastError((DWORD)E_UNEXPECTED);
00303                 NegotiateError();
00304                 return;
00305             }
00306 
00307             // So on re-entry we don't still have SEC_E_INCOMPLETE_MESSAGE
00308             // and read again!
00309             m_scRet = SEC_I_CONTINUE_NEEDED;
00310             int ret = T::recv(&m_handshakeIo[m_cbHandshakeIo], IO_BUFFER_SIZE-m_cbHandshakeIo);
00311             if (ret != 0)
00312                 NegotiateError();
00313             return;
00314         }
00315 
00316         //
00317         // InBuffers[1] is for getting extra data that
00318         //  SSPI/SCHANNEL doesn't proccess on this
00319         //  run around the loop.
00320         //
00321         m_InBuffers[0].pvBuffer = m_handshakeIo.get();
00322         m_InBuffers[0].cbBuffer = m_cbHandshakeIo ;
00323         m_InBuffers[0].BufferType = SECBUFFER_TOKEN;
00324 
00325         m_InBuffers[1].pvBuffer   = NULL;
00326         m_InBuffers[1].cbBuffer   = 0;
00327         m_InBuffers[1].BufferType = SECBUFFER_EMPTY;
00328 
00329         m_InBuffer.cBuffers        = 2;
00330         m_InBuffer.pBuffers        = m_InBuffers;
00331         m_InBuffer.ulVersion       = SECBUFFER_VERSION;
00332 
00333         //
00334         // Initialize these so if we fail, pvBuffer contains NULL,
00335         // so we don't try to free random garbage at the quit
00336         //
00337 
00338         //
00339         //  set OutBuffer for InitializeSecurityContext call
00340         //
00341         m_OutBuffer.cBuffers        = 1;
00342         m_OutBuffer.pBuffers        = m_OutBuffers;
00343         m_OutBuffer.ulVersion       = SECBUFFER_VERSION;
00344 
00345         m_OutBuffers[0].pvBuffer    = NULL;
00346         m_OutBuffers[0].BufferType  = SECBUFFER_TOKEN;
00347         m_OutBuffers[0].cbBuffer    = 0;
00348 
00349         DWORD dwSSPIOutFlags = 0;
00350         DWORD dwSSPIFlags =   
00351                 ASC_REQ_SEQUENCE_DETECT |
00352                 ASC_REQ_REPLAY_DETECT   |
00353                 ASC_REQ_CONFIDENTIALITY |
00354                 ASC_REQ_EXTENDED_ERROR  |
00355                 ASC_REQ_ALLOCATE_MEMORY |
00356                 ASC_REQ_STREAM |
00357                 (m_bRequireClientAuth ? ASC_REQ_MUTUAL_AUTH : 0);
00358         TimeStamp tsExpiry = {0};
00359 
00360         // Initial context if we don't already have one...
00361         bool fInitContext = !SecIsValidHandle(&m_hContext);
00362 
00363         // This is in here because it fucks up on Windows NT if
00364         // you feed it rubbish data (for example, via telnet).  Marvellous.
00365         __try 
00366         {
00367             m_scRet = g_SecurityFunc.AcceptSecurityContext(
00368                             &m_hCreds,
00369                             (fInitContext?NULL:&m_hContext),
00370                             &m_InBuffer,
00371                             dwSSPIFlags,
00372                             SECURITY_NATIVE_DREP,
00373                             (fInitContext?&m_hContext:NULL),
00374                             &m_OutBuffer,
00375                             &dwSSPIOutFlags,
00376                             &tsExpiry);
00377         } 
00378         __except(EXCEPTION_EXECUTE_HANDLER)
00379         {
00380             m_scRet = GetExceptionCode();
00381             if (!FAILED(m_scRet)) m_scRet = E_UNEXPECTED;
00382         }
00383 
00384         // Send some data if we have it
00385         if ( m_scRet == SEC_E_OK ||
00386              m_scRet == SEC_I_CONTINUE_NEEDED ||
00387              (FAILED(m_scRet) && (0 != (dwSSPIOutFlags & ASC_RET_EXTENDED_ERROR))))
00388         {
00389             if  (m_OutBuffers[0].cbBuffer != 0    &&
00390                  m_OutBuffers[0].pvBuffer != NULL )
00391             {
00392 
00393                 //
00394                 // Send response to server if there is one
00395                 //
00396                 int ret = handshakeSend((const char *)m_OutBuffers[0].pvBuffer, m_OutBuffers[0].cbBuffer);
00397                 if (ret != 0)
00398                 {
00399                     NegotiateError();
00400                 }
00401                 return;
00402             }
00403         } else if (m_OutBuffers[0].pvBuffer != NULL) {
00404             // Basically, we shouldn't have an output buffer here unless SCHANNEL
00405             // has indicated that it has returned one.  Oh well, free it anyway;
00406             // just to be sure.
00407             g_SecurityFunc.FreeContextBuffer( m_OutBuffers[0].pvBuffer );
00408             m_OutBuffers[0].pvBuffer = NULL;
00409         }
00410 
00411         ProcessDecryptedData();
00412     }
00413 }
00414 
00415 template <class T>
00416 bool CAsyncSecureSocketServer<T>::ProcessDecryptedData()
00417 {
00418     // whoopy!
00419     if ( m_scRet == SEC_E_OK )
00420     {
00421         // Query the maximum message sizes
00422         SECURITY_STATUS m_scRet = querySizes();
00423         if (m_scRet != SEC_E_OK)
00424         {
00425             SetLastError(m_scRet);
00426             NegotiateError();
00427             return false;
00428         }
00429         
00430         // Delete the old extra data buffer if we created one
00431         // size may have changed?
00432         m_Extra.reset(0);
00433 
00434         // work out how much space to allocate for a new one.  Note - we 
00435         // will not allocate an extra buffer unless necessary.
00436         DWORD dwExtraSize = m_Sizes.cbMaximumMessage+m_Sizes.cbHeader+m_Sizes.cbTrailer;
00437         if ( m_InBuffers[1].BufferType == SECBUFFER_EXTRA )
00438         {
00439             // If not enough space by default, allocate some more...
00440             if (dwExtraSize < m_InBuffers[1].cbBuffer)
00441                 dwExtraSize = m_InBuffers[1].cbBuffer;
00442 
00443             m_ExtraCount = m_InBuffers[1].cbBuffer;
00444             m_Extra.reset(new char[dwExtraSize]);
00445             CopyMemory(m_Extra.get(), (LPBYTE) (m_handshakeIo.get() + 
00446                 (m_cbHandshakeIo - m_InBuffers[1].cbBuffer)), m_ExtraCount);
00447         }
00448 
00449         // all done
00450         ConnectionState PrevState = m_State;
00451         m_State = State_Connected;
00452         switch (PrevState)
00453         {
00454         case State_Negotiate:
00455             m_pCallback->onConnectCompletion(TRUE);
00456             break;
00457         case State_Renegotiate:
00458             ReadCompletion(TRUE, 0);
00459             break;
00460         }
00461         return false;
00462     }
00463 
00464     if (FAILED(m_scRet) && (m_scRet != SEC_E_INCOMPLETE_MESSAGE))
00465     {
00466         // pants, it went wrong
00467         SetLastError(m_scRet);
00468         NegotiateError();
00469         return false;
00470     }
00471 
00472     // If we've potentially got extra data, then process it by copying it
00473     // to our SCHANNEL input buffer
00474     if ( m_scRet != SEC_E_INCOMPLETE_MESSAGE &&
00475          m_scRet != SEC_I_INCOMPLETE_CREDENTIALS)
00476     {
00477 
00478         if ( m_InBuffers[1].BufferType == SECBUFFER_EXTRA )
00479         {
00480             // No need to check the amount of data here as we are already
00481             // fetching it from the input buffer space...
00482             MoveMemory(m_handshakeIo.get(),
00483                    (LPBYTE) (m_handshakeIo.get() + (m_cbHandshakeIo - m_InBuffers[1].cbBuffer)),
00484                     m_InBuffers[1].cbBuffer);
00485             m_cbHandshakeIo = m_InBuffers[1].cbBuffer;
00486         }
00487         else
00488         {
00489             //
00490             // prepare for next receive
00491             //
00492             m_cbHandshakeIo = 0;
00493         }
00494     }
00495     return true;
00496 }
00497 
00498 template <class T>
00499 int CAsyncSecureSocketServer<T>::shutdown(int how)
00500 {
00501     if (m_State != State_Connected) {
00502         return  CAsyncSecureSocket<T>::shutdown(how);
00503     }
00504 
00505     //
00506     // Notify schannel that we are about to close the connection.
00507     //
00508 
00509     DWORD dwType = SCHANNEL_SHUTDOWN;
00510 
00511     SecBuffer OutBuffers[1];
00512     OutBuffers[0].pvBuffer   = &dwType;
00513     OutBuffers[0].BufferType = SECBUFFER_TOKEN;
00514     OutBuffers[0].cbBuffer   = sizeof(dwType);
00515 
00516     SecBufferDesc OutBuffer;
00517     OutBuffer.cBuffers  = 1;
00518     OutBuffer.pBuffers  = OutBuffers;
00519     OutBuffer.ulVersion = SECBUFFER_VERSION;
00520 
00521     SECURITY_STATUS Status = g_SecurityFunc.ApplyControlToken(&m_hContext, &OutBuffer);
00522 
00523     if (FAILED(Status))
00524     {
00525         SetLastError(Status);
00526         return SOCKET_ERROR;
00527     }
00528 
00529     //
00530     // Build an SSL close notify message.
00531     //
00532 
00533     DWORD dwSSPIFlags =   
00534         ASC_REQ_SEQUENCE_DETECT |
00535         ASC_REQ_REPLAY_DETECT   |
00536         ASC_REQ_CONFIDENTIALITY |
00537         ASC_REQ_EXTENDED_ERROR  |
00538         ASC_REQ_ALLOCATE_MEMORY |
00539         ASC_REQ_STREAM;
00540 
00541     OutBuffers[0].pvBuffer   = NULL;
00542     OutBuffers[0].BufferType = SECBUFFER_TOKEN;
00543     OutBuffers[0].cbBuffer   = 0;
00544 
00545     OutBuffer.cBuffers  = 1;
00546     OutBuffer.pBuffers  = OutBuffers;
00547     OutBuffer.ulVersion = SECBUFFER_VERSION;
00548 
00549     DWORD dwSSPIOutFlags = 0;
00550     TimeStamp tsExpiry = {0};
00551     Status = g_SecurityFunc.AcceptSecurityContext(
00552                     &m_hCreds,
00553                     &m_hContext,
00554                     NULL,
00555                     dwSSPIFlags,
00556                     SECURITY_NATIVE_DREP,
00557                     NULL,
00558                     &OutBuffer,
00559                     &dwSSPIOutFlags,
00560                     &tsExpiry);
00561 
00562     if (FAILED(Status))
00563     {
00564         SetLastError(Status);
00565         return SOCKET_ERROR;
00566     }
00567 
00568     PBYTE pbMessage = (unsigned char *)OutBuffers[0].pvBuffer;
00569     DWORD cbMessage = OutBuffers[0].cbBuffer;
00570 
00571     //
00572     // Send the close notify message to the client.
00573     //
00574 
00575     // Nothing to send?
00576     if (pbMessage == NULL || cbMessage == 0)
00577     {
00578         return CAsyncSecureSocket<T>::shutdown(how);
00579     }
00580 
00581     // Send the data
00582     m_State = State_Shutdown;
00583     m_shutdownHow = how;
00584     if (handshakeSend((const char*)pbMessage, (int)cbMessage) != 0)
00585     {
00586         return SOCKET_ERROR;
00587     }
00588     return 0;
00589 }
00590 
00591 template <class T>
00592 SECURITY_STATUS CAsyncSecureSocketServer<T>::setServerCertificate(PCCERT_CONTEXT pCertContext,
00593     DWORD dwEnabledProtocols)
00594 {
00595     SECURITY_STATUS scRet = createCredentialsFromCertificate(&m_hCreds, pCertContext, 
00596         dwEnabledProtocols);
00597     if (SUCCEEDED(scRet)) 
00598         m_ownCredentials = true;
00599     return scRet;
00600 }
00601 
00602 } // namespace OW32
00603 
00604 #endif // OW32_AsyncSecureSocketServer_h

Generated on Sun Jun 5 01:29:17 2005 for OW32 by  doxygen 1.3.9.1