#include "TLSSession.h"

#include <fcntl.h>

#include <cerrno>
#include <cstring>

#include "EPoll.h"
#include "Log.h"
#include "Exception.h"

namespace core {

   /// Session-id generator meant for SSL_set_generate_session_id(); its
   /// registration in onRegister() is currently commented out.
   ///
   /// Fills id[0 .. *id_len) with random bytes, overwrites the start with the
   /// "BARANT" prefix (clipped to *id_len), and retries up to 10 times while
   /// SSL_has_matching_session_id() reports a collision.
   ///
   /// NOTE(review): OpenSSL's GEN_SESSION_CB type appears to take a non-const
   /// SSL*; this signature may need adjusting before it can be registered.
   /// Confirm against the OpenSSL version in use.
   ///
   /// @return 1 when a unique id was produced; 0 when RAND_bytes fails or no
   ///         unique id was found within the retry limit.
   static int generate_session_id(const SSL *ssl, unsigned char *id, unsigned int *id_len) {
      static const char session_id_prefix[] = "BARANT";
      const unsigned int max_attempts = 10;

      // Copy at most *id_len bytes of the prefix. The previous code passed the
      // boolean result of the comparison as the length (copying 0 or 1 byte).
      const size_t prefix_len = strlen(session_id_prefix) < *id_len
                                ? strlen(session_id_prefix)
                                : *id_len;

      unsigned int count = 0;
      coreutils::Log(coreutils::LOG_DEBUG_3) << "Generating unique session id.";
      do {
         if(RAND_bytes(id, *id_len) != 1)
            return 0;
         memcpy(id, session_id_prefix, prefix_len);
      } while(SSL_has_matching_session_id(ssl, id, *id_len) && (++count < max_attempts));

      // count only reaches max_attempts when every attempt collided.
      return count < max_attempts ? 1 : 0;
   }

   /// Info callback installed with SSL_set_info_callback() in onRegister().
   ///
   /// Logs every state transition. When the handshake completes, it requires a
   /// peer certificate whose verification result is X509_V_OK.
   ///
   /// NOTE(review): this callback runs from inside OpenSSL calls
   /// (SSL_accept / SSL_read), so these throws unwind through OpenSSL's C
   /// frames, which is not guaranteed to be safe. Consider recording the
   /// failure and acting on it after the OpenSSL call returns.
   ///
   /// @throws std::string if no peer certificate is available or its
   ///         verification failed.
   void handshake_complete(const SSL *ssl, int where, int ret) {
      coreutils::Log(coreutils::LOG_DEBUG_3) << "==>" << SSL_state_string_long(ssl) << "<==" << ret;

      // Only the handshake-done event is checked. The previous version logged
      // "No client certificate." for every other state change, which was misleading.
      if(!(where & SSL_CB_HANDSHAKE_DONE))
         return;

      X509 *ssl_client_cert = SSL_get_peer_certificate(ssl);
      if(!ssl_client_cert)
         throw std::string("Unable to get peer certificate.");
      X509_free(ssl_client_cert);

      if(SSL_get_verify_result(ssl) != X509_V_OK)
         throw std::string("Certificate verification failed.");

      coreutils::Log(coreutils::LOG_DEBUG_3) << "Certificate verified successfully.";
   }

   /// Forwards ePoll and server to TCPSession; the TLS state is created
   /// later, in onRegister().
   TLSSession::TLSSession(EPoll &ePoll, TCPServer &server) : TCPSession(ePoll, server) {}

   /// Prepares the TLS state for this session's descriptor.
   ///
   /// Puts the descriptor in non-blocking mode, creates the SSL object from
   /// the server's SSL_CTX, installs handshake_complete() as the info
   /// callback, and binds the SSL object to the descriptor.
   ///
   /// @throws std::string if SSL_new or SSL_set_fd fails.
   void TLSSession::onRegister() {
      initialized = true;

      coreutils::Log(coreutils::LOG_DEBUG_3) << "TLS socket initializing on socket " << getDescriptor() << "...";

      // Non-blocking, so SSL_accept/SSL_read report WANT_READ/WANT_WRITE
      // instead of blocking.
      fcntl(getDescriptor(), F_SETFL, fcntl(getDescriptor(), F_GETFL, 0) | O_NONBLOCK);

      ssl = SSL_new(static_cast<TLSServer &>(server).ctx);
      if(!ssl)
         throw std::string("Error creating new TLS socket.");

      SSL_set_info_callback(ssl, handshake_complete);

      if(SSL_set_fd(ssl, getDescriptor()) == 0)
         throw std::string("Error setting TLS socket descriptor.");

      // if(!SSL_set_generate_session_id(ssl, generate_session_id))
      //    throw std::string("Error setting session identifier callback.");
   }

   /// Attempts the server side of the TLS handshake once and logs the outcome.
   /// An SSL_ERROR_SYSCALL result shuts the session down.
   void TLSSession::onRegistered() {
      switch (SSL_get_error(ssl, SSL_accept(ssl))) {
         case SSL_ERROR_NONE:
            // SSL_accept succeeded; previously this fell into "Unknown ERROR".
            coreutils::Log(coreutils::LOG_DEBUG_3) << "TLS handshake completed on ssl_accept.";
            break;
         case SSL_ERROR_SSL:
            coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_SSL on ssl_accept. errno=" << errno;
            break;
         case SSL_ERROR_WANT_READ:
            coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_WANT_READ on ssl_accept.";
            break;
         case SSL_ERROR_WANT_WRITE:
            coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_WANT_WRITE on ssl_accept.";
            break;
         case SSL_ERROR_SYSCALL:
            coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_SYSCALL on ssl_accept. errno=" << errno;
            shutdown();
            break;
         default:
            coreutils::Log(coreutils::LOG_DEBUG_3) << "Unknown ERROR on ssl_accept.";
            break;
      }
   }

   /// Releases the SSL object created in onRegister() (previously leaked).
   /// SSL_free does not close the descriptor bound with SSL_set_fd.
   ///
   /// NOTE(review): assumes `initialized` defaults to false in the class
   /// declaration, so sessions that never registered skip SSL_free. Confirm in
   /// TLSSession.h.
   TLSSession::~TLSSession() {
      if(initialized)
         SSL_free(ssl);
   }

   /// Intentionally empty: TLSSession handles no protocol data itself.
   void TLSSession::protocol(coreutils::ZString &data) {}

   /// Reads decrypted bytes from the TLS connection into buffer.
   ///
   /// On success, only the bytes actually read are passed to onDataReceived().
   /// On failure (SSL_read <= 0), the SSL_get_error reason is logged.
   ///
   /// NOTE(review): SSL_read may leave decrypted bytes buffered inside OpenSSL
   /// (see SSL_pending); such bytes do not make the descriptor readable again.
   /// Confirm the event loop copes with this.
   void TLSSession::receiveData(coreutils::ZString &buffer) {
      int len = ::SSL_read(ssl, buffer.getData(), buffer.getLength());

      if(len > 0) {
         // Log only the length: the old std::cout dump wrote the whole decrypted
         // buffer to stdout and flushed on every read.
         coreutils::Log(coreutils::LOG_DEBUG_3) << "receiveData TLS...len=" << len;
         // SSL_read filled only the first len bytes; pass just those on.
         coreutils::ZString received(buffer.getData(), len);
         onDataReceived(received);
         return;
      }

      // SSL_read returned 0 or a negative value: the read failed.
      switch (SSL_get_error(ssl, len)) {
         case SSL_ERROR_ZERO_RETURN:
            // NOTE(review): the peer closed the TLS connection; confirm whether
            // the session should be shut down here.
            coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_ZERO_RETURN on ssl_read.";
            break;
         case SSL_ERROR_SSL:
            coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_SSL on ssl_read. error=" << errno;
            break;
         case SSL_ERROR_WANT_READ:
            coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_WANT_READ on ssl_read.";
            break;
         case SSL_ERROR_WANT_WRITE:
            coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_WANT_WRITE on ssl_read.";
            break;
         case SSL_ERROR_SYSCALL:
            coreutils::Log(coreutils::LOG_DEBUG_3) << "ERROR_SYSCALL on ssl_read. errno=" << errno;
            break;
         default:
            coreutils::Log(coreutils::LOG_DEBUG_3) << "Unknown ERROR on ssl_read.";
            break;
      }
   }

   /// Appends "|" followed by the client address and port to out.
   void TLSSession::output(std::stringstream &out) {
      out << "|" << ipAddress.getClientAddressAndPort();
   }

}