127 lines
4.1 KiB
C++
127 lines
4.1 KiB
C++
#include "TLSSession.h"
|
|
#include "EPoll.h"
|
|
#include "Log.h"
|
|
#include "Exception.h"
|
|
|
|
namespace core {
|
|
|
|
static int generate_session_id(const SSL *ssl, unsigned char *id, unsigned int *id_len) {
|
|
char *session_id_prefix = (char *)"BARANT";
|
|
unsigned int count = 0;
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "Generating unique session id.";
|
|
do {
|
|
RAND_bytes(id, *id_len);
|
|
memcpy(id, session_id_prefix, (strlen(session_id_prefix) < *id_len));
|
|
} while(SSL_has_matching_session_id(ssl, id, *id_len) && (++count < 10));
|
|
return 1;
|
|
}
|
|
|
|
void handshake_complete(const SSL *ssl, int where, int ret) {
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "==>" << SSL_state_string_long(ssl) << "<==" << ret;
|
|
if(where & SSL_CB_HANDSHAKE_DONE) {
|
|
X509 *ssl_client_cert = SSL_get_peer_certificate(ssl);
|
|
if(!ssl_client_cert)
|
|
throw std::string("Unable to get peer certificate.");
|
|
X509_free(ssl_client_cert);
|
|
if(SSL_get_verify_result(ssl) != X509_V_OK)
|
|
throw std::string("Certificate verification failed.");
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "Certificate verified successfully.";
|
|
}
|
|
else
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "No client certificate.";
|
|
}
|
|
|
|
/**
 * Constructs a TLS session for `server`, driven by the event loop `ePoll`.
 *
 * This only forwards both arguments to TCPSession. The SSL object itself is
 * created later, in onRegister().
 */
TLSSession::TLSSession(EPoll &ePoll, TCPServer &server)
   : TCPSession(ePoll, server)
{
}
|
|
|
|
void TLSSession::onRegister() {
|
|
initialized = true;
|
|
int ret;
|
|
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "TLS socket initializing on socket " << getDescriptor() << "...";
|
|
|
|
fcntl(getDescriptor(), F_SETFL, fcntl(getDescriptor(), F_GETFL, 0) | O_NONBLOCK);
|
|
|
|
ssl = SSL_new(static_cast<TLSServer &>(server).ctx);
|
|
if(ssl <= 0)
|
|
throw std::string("Error creating new TLS socket.");
|
|
|
|
SSL_set_info_callback(ssl, handshake_complete);
|
|
|
|
if((ret = SSL_set_fd(ssl, getDescriptor())) == 0)
|
|
throw std::string("Error setting TLS socket descriptor.");
|
|
|
|
if(!SSL_set_generate_session_id(ssl, generate_session_id))
|
|
throw std::string("Error setting session identifier callback.");
|
|
|
|
}
|
|
|
|
void TLSSession::onRegistered() {
|
|
|
|
switch (SSL_get_error(ssl, SSL_accept(ssl))) {
|
|
case SSL_ERROR_SSL:
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_SSL on ssl_accept. errno=" << errno;
|
|
break;
|
|
case SSL_ERROR_WANT_READ:
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_WANT_READ on ssl_accept.";
|
|
break;
|
|
case SSL_ERROR_WANT_WRITE:
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_WANT_WRITE on ssl_accept.";
|
|
break;
|
|
case SSL_ERROR_SYSCALL:
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_SYSCALL on ssl_accept. errno=" << errno;
|
|
shutdown();
|
|
break;
|
|
default:
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "Unknown ERROR on ssl_accept.";
|
|
break;
|
|
}
|
|
|
|
}
|
|
|
|
/**
 * Destroys the session. No TLS-specific cleanup is performed here.
 *
 * NOTE(review): the SSL object created by SSL_new() in onRegister() is not
 * released here. Unless other code calls SSL_free() on it, every session
 * leaks it. Freeing it here requires `ssl` to be reliably initialised
 * (e.g. to NULL) in the class declaration, which is not visible from this
 * file; confirm before adding SSL_free(ssl).
 */
TLSSession::~TLSSession() {}
|
|
|
|
/**
 * Protocol hook for TLS sessions. It is currently a no-op: nothing is
 * written to `out` and `data` is ignored.
 */
void TLSSession::protocol(std::stringstream &out, std::string data) {
   // Explicitly discard both arguments; this layer does not interpret data.
   (void)out;
   (void)data;
}
|
|
|
|
void TLSSession::receiveData(char *buffer, int bufferLength) {
|
|
|
|
int len;
|
|
// int error = -1;
|
|
//
|
|
std::cout << "receiveData TLS" << std::endl;
|
|
|
|
if((len = ::SSL_read(ssl, buffer, bufferLength)) >= 0) {
|
|
std::cout << "receiveData TLS...len=" << len << ":" << buffer << std::endl;
|
|
onDataReceived(std::string(buffer, len));
|
|
}
|
|
else {
|
|
switch (SSL_get_error(ssl, len)) {
|
|
case SSL_ERROR_SSL:
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_SSL on ssl_read. error=" << errno;
|
|
break;
|
|
case SSL_ERROR_WANT_READ:
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_WANT_READ on ssl_read.";
|
|
break;
|
|
case SSL_ERROR_WANT_WRITE:
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_WANT_WRITE on ssl_read.";
|
|
break;
|
|
case SSL_ERROR_SYSCALL:
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_SYSCALL on ssl_read. errno=" << errno;
|
|
break;
|
|
default:
|
|
coreutils::Log(coreutils::LOG_DEBUG_3) << "Unknown ERROR on ssl_read.";
|
|
break;
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
/**
 * Appends a "|" separator followed by this session's client address and port
 * (as reported by ipAddress.getClientAddressAndPort()) to `out`.
 */
void TLSSession::output(std::stringstream &out) {
   out << "|";
   out << ipAddress.getClientAddressAndPort();
}
|
|
|
|
}
|
|
|