ServerCore/TLSSession.cpp

128 lines
3.8 KiB
C++

#include "TLSSession.h"
#include "TLSService.h"
#include "EPoll.h"
#include "Log.h"
#include "Exception.h"
//#include <openssl/rand.h>
namespace core {
/// Session-id generation callback handed to SSL_set_generate_session_id().
///
/// Fills `id` with `*id_len` random bytes, overwrites the leading bytes with the
/// fixed prefix "BARANT" (clipped to `*id_len` when the id is shorter than the
/// prefix), and retries while OpenSSL reports the id as already in use.
///
/// @param ssl     connection the id is generated for; used for the collision check.
/// @param id      output buffer receiving the session id.
/// @param id_len  in: requested id length in bytes (left unchanged).
/// @return 1 when a non-colliding id was produced; 0 when the random source
///         failed or every attempt collided, which is how the callback reports
///         failure to OpenSSL.
static int generate_session_id(const SSL *ssl, unsigned char *id, unsigned int *id_len) {
    static const char session_id_prefix[] = "BARANT";
    static const unsigned int max_attempts = 10;
    const unsigned int prefix_len = sizeof(session_id_prefix) - 1;   // exclude the NUL
    unsigned int count = 0;

    Log(LOG_DEBUG_3) << "Generating unique session id.";
    do {
        if(RAND_bytes(id, static_cast<int>(*id_len)) <= 0)
            return 0;
        // Copy the whole prefix, or as much of it as fits in the id.
        memcpy(id, session_id_prefix, (prefix_len < *id_len) ? prefix_len : *id_len);
    } while(SSL_has_matching_session_id(ssl, id, *id_len) && (++count < max_attempts));

    // count only reaches max_attempts when the last candidate still collided.
    return (count < max_attempts) ? 1 : 0;
}
/// OpenSSL info callback (installed via SSL_set_info_callback) that logs every
/// reported state and, once the handshake is done, checks the peer certificate.
///
/// @param ssl    connection the event belongs to.
/// @param where  bitmask describing the event; only SSL_CB_HANDSHAKE_DONE is acted on.
/// @param ret    event-specific value supplied by OpenSSL; not used here.
/// @throws std::string when the handshake finished without a peer certificate
///         or the certificate did not verify.
///
/// NOTE(review): OpenSSL invokes this from inside its own C code, so these
/// exceptions unwind through C frames — confirm that is acceptable for the
/// toolchain in use, or record the failure and act on it after SSL_accept/SSL_read.
void handshake_complete(const SSL *ssl, int where, int ret) {
    (void)ret;
    Log(LOG_DEBUG_3) << "==>" << SSL_state_string_long(ssl) << "<==";

    // Every other event is an ordinary state change or alert; it says nothing
    // about client certificates, so there is nothing further to check or report.
    if(!(where & SSL_CB_HANDSHAKE_DONE))
        return;

    X509 *ssl_client_cert = SSL_get_peer_certificate(ssl);
    if(!ssl_client_cert)
        throw std::string("Unable to get peer certificate.");
    // Only the presence of the certificate matters here; drop our reference.
    X509_free(ssl_client_cert);

    if(SSL_get_verify_result(ssl) != X509_V_OK)
        throw std::string("Certificate verification failed.");
    Log(LOG_DEBUG_3) << "Certificate verified successfully.";
}
/// Constructs the session by forwarding the poller and service to Session.
/// No TLS work happens here; receiveData() calls init() on first use.
TLSSession::TLSSession(EPoll &ePoll, Service &service)
    : Session(ePoll, service)
{
}
void TLSSession::init() {
initialized = true;
int ret;
Log(LOG_DEBUG_3) << "TLS socket initializing...";
fcntl(getDescriptor(), F_SETFL, fcntl(getDescriptor(), F_GETFL, 0) | O_NONBLOCK);
if(!(ssl = SSL_new(((TLSService &)service).ctx)))
throw std::string("Error creating new TLS socket.");
SSL_set_info_callback(ssl, handshake_complete);
if((ret = SSL_set_fd(ssl, getDescriptor())) == 0)
throw std::string("Error setting TLS socket descriptor.");
if(!SSL_set_generate_session_id(ssl, generate_session_id))
throw std::string("Error setting session identifier callback.");
switch (SSL_get_error(ssl, SSL_accept(ssl))) {
case SSL_ERROR_SSL:
Log(LOG_DEBUG_3) << "ERROR_SSL on ssl_accept. errno=" << errno;
break;
case SSL_ERROR_WANT_READ:
Log(LOG_DEBUG_3) << "ERROR_WANT_READ on ssl_accept.";
break;
case SSL_ERROR_WANT_WRITE:
Log(LOG_DEBUG_3) << "ERROR_WANT_WRITE on ssl_accept.";
break;
case SSL_ERROR_SYSCALL:
Log(LOG_DEBUG_3) << "ERROR_SYSCALL on ssl_accept. errno=" << errno;
shutdown();
break;
default:
Log(LOG_DEBUG_3) << "Unknown ERROR on ssl_accept.";
break;
}
}
/// Releases the SSL object created by init().
///
/// `ssl` is only assigned inside init(), so it is inspected only when init()
/// has run; if SSL_new failed there, `ssl` is null and nothing is freed.
/// NOTE(review): assumes no other code path frees `ssl` — none is visible in
/// this file; confirm against TLSSession.h / Session.
TLSSession::~TLSSession() {
    if(initialized && ssl)
        SSL_free(ssl);
}
/// Protocol hook. Currently a no-op: the received text is accepted and ignored.
void TLSSession::protocol(std::string data)
{
    (void)data;
}
void TLSSession::receiveData(char *buffer, int bufferLength) {
if(!initialized)
init();
int len;
// int error = -1;
//
std::cout << "receiveData TLS" << std::endl;
if((len = ::SSL_read(ssl, buffer, bufferLength)) >= 0) {
std::cout << "receiveData TLS...len=" << len << ":" << buffer << std::endl;
onDataReceived(std::string(buffer, len));
}
else {
switch (SSL_get_error(ssl, len)) {
case SSL_ERROR_SSL:
Log(LOG_DEBUG_3) << "ERROR_SSL on ssl_read. error=" << errno;
break;
case SSL_ERROR_WANT_READ:
Log(LOG_DEBUG_3) << "ERROR_WANT_READ on ssl_read.";
break;
case SSL_ERROR_WANT_WRITE:
Log(LOG_DEBUG_3) << "ERROR_WANT_WRITE on ssl_read.";
break;
case SSL_ERROR_SYSCALL:
Log(LOG_DEBUG_3) << "ERROR_SYSCALL on ssl_read. errno=" << errno;
break;
default:
Log(LOG_DEBUG_3) << "Unknown ERROR on ssl_read.";
break;
}
}
}
/// Appends this session's summary to `out`: a "|" separator followed by the
/// text returned by ipAddress.getClientAddressAndPort().
void TLSSession::output(std::stringstream &out)
{
    out << "|";
    out << ipAddress.getClientAddressAndPort();
}
}