00001
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #include "KjSslConnection.h"
00035
00036 KjSslConnection::KjSslConnection(KjSslConnection::EKjSslConType _eType, char* _strConn,const string &_strCertFile, const string &_strKeyFile): m_eType(_eType)
00037 {
00038 if(!Init(_strCertFile, _strKeyFile))
00039 {
00040 m_bIsValid = false;
00041 return;
00042 }
00043 if(m_eType==KJ_SSL_CLIENT)
00044 {
00045 m_bIsValid = ClientConnect(_strConn);
00046 }
00047 else if (m_eType==KJ_SSL_SERVER)
00048 {
00049 m_bIsValid = ServerConnect(_strConn);
00050 }
00051 }
00052
00053 KjSslConnection::KjSslConnection(BIO* _pBio,const string &_strCertFile, const string &_strKeyFile):
00054 m_eType(KjSslConnection::KJ_SSL_SERVER),m_pBio(_pBio)
00055 {
00056 m_iId = -1;
00057 if(!Init(_strCertFile, _strKeyFile))
00058 {
00059 m_bIsValid = false;
00060 return;
00061 }
00062 m_bIsValid = ServerConnect("");
00063 }
00064
00065 KjSslConnection::~KjSslConnection()
00066 {
00067 SSL_free(m_pSsl);
00068 m_bIsValid = false;
00069 }
00070
00075 bool KjSslConnection::Init(const string& _strCertFile, const string& _strKeyFile)
00076 {
00077 m_pSsl=SSL_new(GetCtx());
00078 m_bIsValid=(bool)m_pSsl;
00079 if(!m_bIsValid)
00080 {
00081 return false;
00082 }
00083 if(!_strCertFile.empty())
00084 {
00085 if(SSL_use_certificate_file(m_pSsl,_strCertFile.c_str(),SSL_FILETYPE_PEM)<=0)
00086 {
00087 CheckErrors("SSL_use_certificate_file");
00088 m_bIsValid=false;
00089 return false;
00090 }
00091 }
00092 if(!_strKeyFile.empty())
00093 {
00094 if(SSL_use_RSAPrivateKey_file(m_pSsl,_strKeyFile.c_str(),SSL_FILETYPE_PEM)<=0)
00095 {
00096 CheckErrors("SSL_use_RSAPrivateKey_file");
00097 m_bIsValid=false;
00098 return false;
00099 }
00100 }
00101 if(!_strKeyFile.empty()&&!_strCertFile.empty())
00102 {
00103 if(SSL_check_private_key(m_pSsl)<=0)
00104 {
00105 CheckErrors("SSL_check_private_key");
00106 m_bIsValid=false;
00107 return false;
00108 }
00109 }
00110 if (SSL_set_cipher_list(m_pSsl,"ALL" )<=0)
00111 {
00112 CheckErrors("SSL_set_cipher_list");
00113 m_bIsValid=false;
00114 return false;
00115 }
00116 return true;
00117 }
00118
00122 bool KjSslConnection::ServerConnect(char* _strConn)
00123 {
00124 if(!m_pBio)
00125 {
00126 m_pBio=BIO_new_accept(_strConn);
00127 if(!m_pBio)
00128 {
00129 return false;
00130 }
00131 BIO_set_bind_mode(m_pBio,BIO_BIND_REUSEADDR);
00132 }
00133 SSL_set_bio(m_pSsl, m_pBio, m_pBio);
00134 SSL_set_accept_state(m_pSsl);
00135 int iRet = SSL_accept(m_pSsl);
00136 if(iRet<=0)
00137 {
00138 CheckErrors("SSL_accept");
00139 CheckSslErrors(m_pSsl, iRet, "SSL_accept");
00140 return false;
00141 }
00142 return true;
00143 }
00144
00148 bool KjSslConnection::ClientConnect(char* _strConn)
00149 {
00150 m_pBio=BIO_new_connect(_strConn);
00151 if(BIO_do_connect(m_pBio)<=0)
00152 {
00153 m_bIsValid=false;
00154 return false;
00155 }
00156 SSL_set_bio(m_pSsl,m_pBio,m_pBio);
00157 SSL_set_connect_state(m_pSsl);
00158
00159 cout << "SSL client connected..." << endl;
00160 int iRet = SSL_connect(m_pSsl);
00161 if(iRet<=0)
00162 {
00163 CheckErrors("SSL_connect");
00164 CheckSslErrors(m_pSsl, iRet, "SSL_connect");
00165 return false;
00166 }
00167 return true;
00168 }
00169
00173 int KjSslConnection::Write(const string& _str)
00174 {
00175 int iRet = 0;
00176 if(CheckSslState())
00177 {
00178 int iStrSize = _str.length();
00179 if(iStrSize > KJ_SSL_READ_CHUNK_SIZE)
00180 {
00181
00182 string strPacket = _str;
00183 for(;;)
00184 {
00185 if(strPacket.length() <= 0)
00186 {
00187 break;
00188 }
00189
00190 string strEnd = strPacket;
00191 if(strPacket.size() > KJ_SSL_READ_CHUNK_SIZE)
00192 {
00193 strEnd.erase(0, KJ_SSL_READ_CHUNK_SIZE);
00194 strPacket.erase(strPacket.begin()+KJ_SSL_READ_CHUNK_SIZE, strPacket.end());
00195 }
00196 else
00197 {
00198 strEnd.erase(0, strPacket.size());
00199 }
00200
00201 iRet = SSL_write(m_pSsl,strPacket.c_str(),strPacket.length());
00202 strPacket = strEnd;
00203 }
00204 }
00205 else
00206 {
00207
00208 iRet = SSL_write(m_pSsl,_str.c_str(),_str.length());
00209 }
00210
00211 string strEnd = KJ_END_MSG;
00212 iRet = SSL_write(m_pSsl,strEnd.c_str(),strEnd.length());
00213 }
00214 else
00215 {
00216 cerr << "Error :: int KjSslConnection::Write(const string& _str) -- SSL state is incorrect...!" << endl;
00217 }
00218 return iRet;
00219 }
00220
00224 int KjSslConnection::Read(string& _str)
00225 {
00226 int iRet = 0;
00227 if(CheckSslState())
00228 {
00229 char str[KJ_SSL_READ_CHUNK_SIZE];
00230 cout << "Waiting for some message ... " << endl;
00231 iRet = SSL_read(m_pSsl,str, KJ_SSL_READ_CHUNK_SIZE);
00232 str[iRet]=0;
00233 _str="";
00234 _str = str;
00235 }
00236 else
00237 {
00238 cerr << "Error :: int KjSslConnection::Read(string& _str) -- SSL state is incorrect...!" << endl;
00239 }
00240 return iRet;
00241 }
00242
00245 bool KjSslConnection::CheckSslState()
00246 {
00247 if(SSL_get_state(m_pSsl) != SSL_ST_OK)
00248 {
00249 int iRet = SSL_accept(m_pSsl);
00250 if(iRet <= 0)
00251 {
00252 unsigned long e;
00253 char *buf;
00254
00255 if((SSL_get_error(m_pSsl, iRet) == SSL_ERROR_WANT_READ) ||
00256 SSL_get_error(m_pSsl, iRet) == SSL_ERROR_WANT_WRITE)
00257 {
00258 cerr << "Error :: bool KjSslConnection::CheckSslState() -- Read blocked, returning" << endl;
00259 return false;
00260 }
00261 e = ERR_get_error();
00262 buf = ERR_error_string(e, NULL);
00263 cerr << "Error :: bool KjSslConnection::CheckSslState() -- in SSL_accept call " << buf << endl;
00264 return false;
00265 }
00266 }
00267 return true;
00268 }