/*
0
0
mirror of https://github.com/mongodb/mongo.git synced 2024-12-01 09:32:32 +01:00
mongodb/util/net/sock.cpp
2011-07-27 09:48:35 -04:00

677 lines
22 KiB
C++
*/

// @file sock.cpp
/* Copyright 2009 10gen Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pch.h"
#include "sock.h"
#include "../background.h"
#if !defined(_WIN32)
# include <sys/socket.h>
# include <sys/types.h>
# include <sys/socket.h>
# include <sys/un.h>
# include <netinet/in.h>
# include <netinet/tcp.h>
# include <arpa/inet.h>
# include <errno.h>
# include <netdb.h>
# if defined(__openbsd__)
# include <sys/uio.h>
# endif
#endif
#ifdef MONGO_SSL
#include <openssl/err.h>
#include <openssl/ssl.h>
#endif
namespace mongo {
// Process-wide switch: when true, SockAddr name resolution accepts any address family
// (AF_UNSPEC); when false it asks getaddrinfo() for IPv4 results only.
static bool _ipv6Enabled = false;

/** Turns IPv6 address resolution on or off for the whole process. */
void enableIPv6(bool state) {
    _ipv6Enabled = state;
}

/** @return whether IPv6 address resolution is currently enabled (false by default). */
bool IPv6Enabled() {
    const bool enabled = _ipv6Enabled;
    return enabled;
}
// --- some global helpers -----
#if defined(_WIN32)
/** Sets TCP_NODELAY (disables Nagle's algorithm) and SO_KEEPALIVE on `sock`.
 *  Failures are reported but not fatal. */
void disableNagle(int sock) {
    int x = 1;
    if ( setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *) &x, sizeof(x)) )
        out() << "ERROR: disableNagle failed" << endl;
    if ( setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char *) &x, sizeof(x)) )
        out() << "ERROR: SO_KEEPALIVE failed" << endl;
}
#else
/** Sets TCP_NODELAY (disables Nagle's algorithm) and, where available, SO_KEEPALIVE on
 *  `sock`. Failures are logged with the errno description but are not fatal. */
void disableNagle(int sock) {
    int x = 1;
    // TCP_NODELAY is a TCP-level option and must be set at level IPPROTO_TCP (Linux's
    // SOL_TCP has the same value). Falling back to SOL_SOCKET on platforms without SOL_TCP
    // would make the kernel read TCP_NODELAY's value as a socket-level option number
    // (1 is SO_DEBUG on BSD-derived systems) instead of disabling Nagle.
    if ( setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *) &x, sizeof(x)) )
        log() << "ERROR: disableNagle failed: " << errnoWithDescription() << endl;
#ifdef SO_KEEPALIVE
    if ( setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char *) &x, sizeof(x)) )
        log() << "ERROR: SO_KEEPALIVE failed: " << errnoWithDescription() << endl;
#endif
}
#endif
/** Converts a getaddrinfo()/getnameinfo() error code into a readable message.
 *  On Windows the code is formatted with errnoWithDescription() instead of
 *  gai_strerror, because gai_strerrorA is not threadsafe there. */
string getAddrInfoStrError(int code) {
#if defined(_WIN32)
    return errnoWithDescription(code);
#else
    const char* message = gai_strerror(code);
    return string(message);
#endif
}
// --- SockAddr
/** Builds the IPv4 wildcard address (INADDR_ANY, i.e. 0.0.0.0) on `sourcePort`. */
SockAddr::SockAddr(int sourcePort) {
    sockaddr_in& sin = as<sockaddr_in>();
    sin.sin_family = AF_INET;
    sin.sin_port = htons(sourcePort);
    sin.sin_addr.s_addr = htonl(INADDR_ANY);
    memset(sin.sin_zero, 0, sizeof(sin.sin_zero));
    addressSize = sizeof(sockaddr_in);
}
/** Builds an address from a host name, numeric IP, or unix-socket path.
 *  - "localhost" is mapped to "127.0.0.1".
 *  - Any string containing '/' is treated as a unix-domain socket path (rejected on Windows,
 *    and rejected if it does not fit in sun_path).
 *  - Otherwise the name is resolved with getaddrinfo(): first numerically (no DNS), then,
 *    if that reports "no name", with a DNS lookup. Only IPv4 results are requested unless
 *    IPv6Enabled().
 *  If resolution fails, the result is the wildcard address SockAddr(port). Only the first
 *  getaddrinfo() result is used.
 */
SockAddr::SockAddr(const char * iporhost , int port) {
    if (!strcmp(iporhost, "localhost"))
        iporhost = "127.0.0.1";
    if (strchr(iporhost, '/')) {
#ifdef _WIN32
        uassert(13080, "no unix socket support on windows", false);
#endif
        uassert(13079, "path to unix socket too long", strlen(iporhost) < sizeof(as<sockaddr_un>().sun_path));
        as<sockaddr_un>().sun_family = AF_UNIX;
        strcpy(as<sockaddr_un>().sun_path, iporhost);
        addressSize = sizeof(sockaddr_un);
    }
    else {
        addrinfo* addrs = NULL;
        addrinfo hints;
        memset(&hints, 0, sizeof(addrinfo));
        hints.ai_socktype = SOCK_STREAM;
        //hints.ai_flags = AI_ADDRCONFIG; // This is often recommended but don't do it. SERVER-1579
        hints.ai_flags |= AI_NUMERICHOST; // first pass tries w/o DNS lookup
        hints.ai_family = (IPv6Enabled() ? AF_UNSPEC : AF_INET);
        StringBuilder ss;
        ss << port;
        int ret = getaddrinfo(iporhost, ss.str().c_str(), &hints, &addrs);
        // old C compilers on IPv6-capable hosts return EAI_NODATA error
#ifdef EAI_NODATA
        int nodata = (ret == EAI_NODATA);
#else
        int nodata = false;
#endif
        if (ret == EAI_NONAME || nodata) {
            // iporhost isn't an IP address, allow DNS lookup
            hints.ai_flags &= ~AI_NUMERICHOST;
            ret = getaddrinfo(iporhost, ss.str().c_str(), &hints, &addrs);
        }
        if (ret) {
            // Don't log for "0.0.0.0": that address is used to construct the global
            // unknownAddress during static (CRT) initialization, when log() may not work yet.
            if( strcmp("0.0.0.0", iporhost) ) {
                // Use getAddrInfoStrError() rather than gai_strerror(): the latter is not
                // threadsafe on Windows.
                log() << "getaddrinfo(\"" << iporhost << "\") failed: " << getAddrInfoStrError(ret) << endl;
            }
            *this = SockAddr(port);
        }
        else {
            //TODO: handle other addresses in linked list;
            assert(addrs->ai_addrlen <= sizeof(sa));
            memcpy(&sa, addrs->ai_addr, addrs->ai_addrlen);
            addressSize = addrs->ai_addrlen;
            freeaddrinfo(addrs);
        }
    }
}
/** True for the IPv4 loopback 127.0.0.1, the IPv6 loopback ::1, and every unix-domain
 *  socket. False for every other address and family. */
bool SockAddr::isLocalHost() const {
    const sa_family_t family = getType();
    if (family == AF_UNIX)
        return true;
    if (family == AF_INET)
        return getAddr() == "127.0.0.1";
    if (family == AF_INET6)
        return getAddr() == "::1";
    return false;
}
/** Textual form of the address. When `includePort` is set, ":port" is appended, except for
 *  unix-domain and unspecified addresses, which have no port. */
string SockAddr::toString(bool includePort) const {
    string result = getAddr();
    if (includePort) {
        const sa_family_t family = getType();
        const bool hasPort = family != AF_UNIX && family != AF_UNSPEC;
        if (hasPort)
            result += mongoutils::str::stream() << ':' << getPort();
    }
    return result;
}
/** Address family (AF_INET, AF_INET6, AF_UNIX, AF_UNSPEC, ...) read from the stored
 *  address's ss_family field. */
sa_family_t SockAddr::getType() const {
    const sa_family_t family = sa.ss_family;
    return family;
}
/** Port in host byte order for IPv4/IPv6 addresses; 0 for unix-domain and unspecified
 *  addresses. Any other family raises massert SOCK_FAMILY_UNKNOWN_ERROR. */
unsigned SockAddr::getPort() const {
    const sa_family_t family = getType();
    if (family == AF_INET)
        return ntohs(as<sockaddr_in>().sin_port);
    if (family == AF_INET6)
        return ntohs(as<sockaddr_in6>().sin6_port);
    if (family == AF_UNIX || family == AF_UNSPEC)
        return 0;
    massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false);
    return 0;
}
/** Address without port.
 *  - IPv4/IPv6: numeric host string from getnameinfo(NI_NUMERICHOST); massert 13082 on failure.
 *  - Unix sockets: the socket path, or "anonymous unix socket" when addressSize is too small
 *    to hold a path.
 *  - Unspecified: "(NONE)".
 *  Any other family raises massert SOCK_FAMILY_UNKNOWN_ERROR. */
string SockAddr::getAddr() const {
    const sa_family_t family = getType();
    if (family == AF_INET || family == AF_INET6) {
        char host[128];
        const int rc = getnameinfo(raw(), addressSize, host, sizeof(host), NULL, 0, NI_NUMERICHOST);
        massert(13082, getAddrInfoStrError(rc), rc == 0);
        return host;
    }
    if (family == AF_UNIX)
        return addressSize > 2 ? as<sockaddr_un>().sun_path : "anonymous unix socket";
    if (family == AF_UNSPEC)
        return "(NONE)";
    massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false);
    return "";
}
/** Two addresses are equal when family, port and the family-specific address bytes
 *  (IPv4 address, IPv6 address, or unix path) all match. All AF_UNSPEC addresses compare
 *  equal. An unknown family raises massert SOCK_FAMILY_UNKNOWN_ERROR. */
bool SockAddr::operator==(const SockAddr& r) const {
    if (getType() != r.getType() || getPort() != r.getPort())
        return false;
    const sa_family_t family = getType();
    if (family == AF_INET)
        return as<sockaddr_in>().sin_addr.s_addr == r.as<sockaddr_in>().sin_addr.s_addr;
    if (family == AF_INET6)
        return memcmp(as<sockaddr_in6>().sin6_addr.s6_addr, r.as<sockaddr_in6>().sin6_addr.s6_addr, sizeof(in6_addr)) == 0;
    if (family == AF_UNIX)
        return strcmp(as<sockaddr_un>().sun_path, r.as<sockaddr_un>().sun_path) == 0;
    if (family == AF_UNSPEC)
        return true; // assume all unspecified addresses are the same
    massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false);
    return false;
}
/** Negation of operator==. */
bool SockAddr::operator!=(const SockAddr& r) const {
    const bool same = (*this == r);
    return !same;
}
/** Strict weak ordering used for sorted containers: by family, then port, then the
 *  family-specific address bytes (IPv4 s_addr value, IPv6 bytes, unix path). AF_UNSPEC
 *  addresses are never less than each other. An unknown family raises massert
 *  SOCK_FAMILY_UNKNOWN_ERROR. */
bool SockAddr::operator<(const SockAddr& r) const {
    const sa_family_t lhsFamily = getType();
    const sa_family_t rhsFamily = r.getType();
    if (lhsFamily != rhsFamily)
        return lhsFamily < rhsFamily;

    const unsigned lhsPort = getPort();
    const unsigned rhsPort = r.getPort();
    if (lhsPort != rhsPort)
        return lhsPort < rhsPort;

    if (lhsFamily == AF_INET)
        return as<sockaddr_in>().sin_addr.s_addr < r.as<sockaddr_in>().sin_addr.s_addr;
    if (lhsFamily == AF_INET6)
        return memcmp(as<sockaddr_in6>().sin6_addr.s6_addr, r.as<sockaddr_in6>().sin6_addr.s6_addr, sizeof(in6_addr)) < 0;
    if (lhsFamily == AF_UNIX)
        return strcmp(as<sockaddr_un>().sun_path, r.as<sockaddr_un>().sun_path) < 0;
    if (lhsFamily == AF_UNSPEC)
        return false;
    massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false);
    return false;
}
// Placeholder address 0.0.0.0:0, constructed during static initialization.
SockAddr unknownAddress( "0.0.0.0", 0 );
// ------ hostname -------------------
/** Resolves `hostname` to a numeric address string (port ignored).
 *  Returns "" when resolution failed. SockAddr falls back to the wildcard 0.0.0.0 in that
 *  case, so that result is mapped to the empty string. */
string hostbyname(const char *hostname) {
    const string resolved = SockAddr(hostname, 0).getAddr();
    return resolved == "0.0.0.0" ? string() : resolved;
}
// --- my --
/** Returns this machine's hostname as reported by gethostname(), or "" (after logging)
 *  if the call fails or yields an empty name. */
string getHostName() {
    char buf[256];
    // Leave room for a terminator and write one ourselves: POSIX does not guarantee
    // NUL-termination when the name is truncated to fit the buffer.
    int ec = gethostname(buf, sizeof(buf) - 1);
    buf[sizeof(buf) - 1] = 0;
    if ( ec || *buf == 0 ) {
        log() << "can't get this server's hostname " << errnoWithDescription() << endl;
        return "";
    }
    return buf;
}
// Cached result of getHostName(); written once by _hostNameCachedInit().
string _hostNameCached;
// One-time initializer for _hostNameCached. Stores "" if getHostName() failed.
static void _hostNameCachedInit() {
_hostNameCached = getHostName();
}
// Guards the single run of _hostNameCachedInit().
boost::once_flag _hostNameCachedInitFlags = BOOST_ONCE_INIT;
/** Returns this server's hostname. getHostName() is invoked only on the first call
 *  (via boost::call_once); later calls return a copy of the cached string. */
string getHostNameCached() {
boost::call_once( _hostNameCachedInit , _hostNameCachedInitFlags );
return _hostNameCached;
}
// --------- SocketException ----------
// Flags passed to every ::send()/::recv() on a plain socket. Where the platform defines
// MSG_NOSIGNAL (which, per POSIX send(), suppresses SIGPIPE on a broken connection), it is
// used for both directions; otherwise no flags are passed.
#if defined(MSG_NOSIGNAL)
const int portSendFlags = MSG_NOSIGNAL;
#else
const int portSendFlags = 0;
#endif
const int portRecvFlags = portSendFlags;
/** "<code> socket exception [<type>] ", followed by "server [<server>] " and the extra text
 *  when those fields are non-empty. */
string SocketException::toString() const {
    stringstream msg;
    msg << _ei.code << " socket exception [" << _type << "] ";
    if ( _server.size() > 0 )
        msg << "server [" << _server << "] ";
    if ( _extra.size() > 0 )
        msg << _extra;
    return msg.str();
}
// ------------ SSLManager -----------------
#ifdef MONGO_SSL
/** Initializes OpenSSL and creates this manager's SSL_CTX.
 *  @param client  true for a client-side context (SSLv23_client_method),
 *                 false for a server-side one (SSLv23_server_method).
 *  Fails with massert 15864 if the context cannot be created. */
SSLManager::SSLManager( bool client ) {
    _client = client;
    // Library-wide initialization; note this runs on every SSLManager construction.
    SSL_library_init();
    SSL_load_error_strings();
    ERR_load_crypto_strings();
    if ( client )
        _context = SSL_CTX_new( SSLv23_client_method() );
    else
        _context = SSL_CTX_new( SSLv23_server_method() );
    massert( 15864 , mongoutils::str::stream() << "can't create SSL Context: " << ERR_error_string(ERR_get_error(), NULL) , _context );
    // Enable OpenSSL's bundle of bug-compatibility workarounds.
    SSL_CTX_set_options( _context, SSL_OP_ALL);
}
/** Loads a PEM certificate (`publicKeyFile`) and then a PEM private key (`privateKeyFile`)
 *  into the context. Fails with massert 15865 / 15866, including OpenSSL's error text. */
void SSLManager::setupPubPriv( const string& privateKeyFile , const string& publicKeyFile ) {
    const int certLoaded = SSL_CTX_use_certificate_file( _context , publicKeyFile.c_str() , SSL_FILETYPE_PEM );
    massert( 15865 ,
             mongoutils::str::stream() << "Can't read SSL certificate from file "
             << publicKeyFile << ":" << ERR_error_string(ERR_get_error(), NULL) ,
             certLoaded );

    const int keyLoaded = SSL_CTX_use_PrivateKey_file( _context , privateKeyFile.c_str() , SSL_FILETYPE_PEM );
    massert( 15866 ,
             mongoutils::str::stream() << "Can't read SSL private key from file "
             << privateKeyFile << " : " << ERR_error_string(ERR_get_error(), NULL) ,
             keyLoaded );
}
int SSLManager::password_cb(char *buf,int num, int rwflag,void *userdata){
SSLManager* sm = (SSLManager*)userdata;
string pass = sm->_password;
strcpy(buf,pass.c_str());
return(pass.size());
}
/** Loads a combined PEM file: the certificate chain first, then the private key from the same
 *  file. `password` is kept and handed to OpenSSL through password_cb, which is registered
 *  before the key is read. Fails with massert 15867 / 15868. */
void SSLManager::setupPEM( const string& keyFile , const string& password ) {
    _password = password;
    const char* path = keyFile.c_str();
    massert( 15867 , "Can't read certificate file" , SSL_CTX_use_certificate_chain_file( _context , path ) );
    SSL_CTX_set_default_passwd_cb( _context, &SSLManager::password_cb );
    SSL_CTX_set_default_passwd_cb_userdata( _context , this );
    massert( 15868 , "Can't read key file" , SSL_CTX_use_PrivateKey_file( _context , path , SSL_FILETYPE_PEM ) );
}
/** Creates a new SSL object from this manager's context and binds it to descriptor `fd`.
 *  No handshake is performed here. Fails with massert 15861 if SSL_new() fails.
 *  @return the new SSL object; the caller takes ownership. */
SSL * SSLManager::secure( int fd ) {
    SSL * conn = SSL_new( _context );
    massert( 15861 , "can't create SSL" , conn );
    SSL_set_fd( conn , fd );
    return conn;
}
#endif
// ------------ Socket -----------------
/** Wraps an already-open descriptor `fd` whose peer is `remote` (presumably one returned by
 *  accept()). No timeout is set, and log level 0 is used for error messages. */
Socket::Socket(int fd , const SockAddr& remote)
    : _fd(fd),
      _remote(remote),
      _timeout(0) {
    _init();
    _logLevel = 0;
}
/** Creates an unconnected Socket (no descriptor until connect()).
 *  @param timeout  send/receive timeout in seconds; connect() applies it when > 0
 *  @param ll       log level used for this socket's error messages */
Socket::Socket( double timeout, int ll ) {
    _fd = -1;
    _timeout = timeout;
    _logLevel = ll;
    _init();
}
/** Shared constructor tail: zeroes the traffic counters and, with SSL, clears any pending
 *  server-side SSL manager (see secureAccepted()/postFork()). */
void Socket::_init() {
    _bytesIn = 0;
    _bytesOut = 0;
#ifdef MONGO_SSL
    _sslAccepted = 0;
#endif
}
/** Releases the SSL object (if any) and closes the descriptor, marking it invalid (-1).
 *  Calling close() again is harmless. */
void Socket::close() {
#ifdef MONGO_SSL
    _ssl.reset();
#endif
    if ( _fd < 0 )
        return;
    closesocket( _fd );
    _fd = -1;
}
#ifdef MONGO_SSL
/** Client side: wraps the connected descriptor in an SSL object created by `ssl` and runs
 *  the client handshake. The SSL_connect() result is not checked here. */
void Socket::secure( SSLManager * ssl ) {
    assert( ssl );
    assert( _fd >= 0 );
    SSL * conn = ssl->secure( _fd );
    _ssl.reset( conn );
    SSL_connect( _ssl.get() );
}
/** Server side: remembers `ssl` so that postFork() can later wrap the descriptor and run
 *  the SSL_accept() handshake. */
void Socket::secureAccepted( SSLManager * ssl ) {
    _sslAccepted = ssl;
}
#endif
/** If secureAccepted() registered an SSLManager, wraps this socket's descriptor in an SSL
 *  object, runs the server-side SSL_accept() handshake (result not checked), and clears the
 *  pending manager. Without MONGO_SSL this is a no-op. */
void Socket::postFork() {
#ifdef MONGO_SSL
    if ( _sslAccepted ) {
        // -1 marks "no descriptor" (see close()), while 0 is a valid descriptor, so check
        // >= 0 as secure() does. assert(_fd) would accept -1 and reject fd 0.
        assert( _fd >= 0 );
        _ssl.reset( _sslAccepted->secure( _fd ) );
        SSL_accept( _ssl.get() );
        _sslAccepted = 0;
    }
#endif
}
/** Background job that performs a blocking ::connect() of `sock` to `remote`, letting
 *  Socket::connect() bound how long it waits for the result. */
class ConnectBG : public BackgroundJob {
public:
    ConnectBG(int sock, const SockAddr& remote) : _sock(sock), _res(-1), _remote(remote) { }
    void run() { _res = ::connect(_sock, _remote.raw(), _remote.addressSize); }
    string name() const { return "ConnectBG"; }
    /** Result of ::connect(): 0 on success, nonzero (-1) on failure or if run() never completed. */
    int inError() const { return _res; }
private:
    int _sock;
    // Initialized so inError() never reads an indeterminate value.
    int _res;
    SockAddr _remote;
};
/** Creates a stream socket of remote's address family and connects it to `remote`.
 *  If _timeout > 0 it is applied to the new socket with setTimeout(). The ::connect() itself
 *  runs on a ConnectBG background job, and this thread waits with bg.wait(5000). That bound
 *  is fixed and independent of _timeout (presumably milliseconds — confirm against
 *  BackgroundJob::wait).
 *  @return true on success. Returns false if the socket could not be created, if connect
 *          failed, or if the wait expired; in the last two cases the socket is closed.
 *  On success, disableNagle() (TCP_NODELAY + keepalive) is applied to non-unix sockets, and
 *  SO_NOSIGPIPE is set where the platform defines it.
 *  NOTE(review): _fd is overwritten without closing a descriptor this Socket may already
 *  hold — confirm callers never connect() a connected Socket twice.
 */
bool Socket::connect(SockAddr& remote) {
_remote = remote;
_fd = socket(remote.getType(), SOCK_STREAM, 0);
if ( _fd == INVALID_SOCKET ) {
log(_logLevel) << "ERROR: connect invalid socket " << errnoWithDescription() << endl;
return false;
}
if ( _timeout > 0 ) {
setTimeout( _timeout );
}
ConnectBG bg(_fd, remote);
bg.go();
// true here is treated as "the background connect finished in time"
if ( bg.wait(5000) ) {
if ( bg.inError() ) {
close();
return false;
}
}
else {
// time out the connect
// NOTE(review): the descriptor is closed while the background ::connect() may still be
// using it; another thread could be handed the same descriptor number before bg finishes.
close();
bg.wait(); // so bg stays in scope until bg thread terminates
return false;
}
if (remote.getType() != AF_UNIX)
disableNagle(_fd);
#ifdef SO_NOSIGPIPE
// osx: socket-level option, presumably to avoid SIGPIPE on writes to a closed peer
const int one = 1;
setsockopt( _fd , SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(int));
#endif
return true;
}
/** Single write attempt: SSL_write() when the socket is SSL-wrapped, otherwise ::send() with
 *  portSendFlags. Returns the raw result of that call. */
int Socket::_send( const char * data , int len ) {
#ifdef MONGO_SSL
    if ( _ssl )
        return SSL_write( _ssl.get() , data , len );
#endif
    return ::send( _fd , data , len , portSendFlags );
}
// sends all data or throws an exception
/** Writes all `len` bytes of `data` (through SSL when enabled), looping over partial writes
 *  and adding each written chunk to _bytesOut.
 *  @param context  label included in log messages
 *  @throws SocketException SEND_TIMEOUT if the write timed out while a timeout is configured,
 *          SEND_ERROR for any other failure. */
void Socket::send( const char * data , int len, const char *context ) {
    while( len > 0 ) {
        int ret = _send( data , len );
        if ( ret <= 0 ) {
            // Any non-positive result is a failure. ::send() returns -1, but SSL_write() may
            // report failure as 0 or another negative value. Treating those as progress would
            // spin forever (0) or grow len (< -1).
            // Capture the OS error code before any logging can overwrite it.
#if defined(_WIN32)
            const int err = WSAGetLastError();
            const bool timedOut = ( err == WSAETIMEDOUT );
#else
            const int err = errno;
            const bool timedOut = ( err == EAGAIN || err == EWOULDBLOCK );
#endif
#ifdef MONGO_SSL
            if ( _ssl ) {
                log() << "SSL Error ret: " << ret << " err: " << SSL_get_error( _ssl.get() , ret )
                      << " " << ERR_error_string(ERR_get_error(), NULL)
                      << endl;
            }
#endif
            if ( timedOut && _timeout != 0 ) {
                log(_logLevel) << "Socket " << context << " send() timed out " << _remote.toString() << endl;
                throw SocketException( SocketException::SEND_TIMEOUT , remoteString() );
            }
            log(_logLevel) << "Socket " << context << " send() "
                           << errnoWithDescription(err) << ' ' << remoteString() << endl;
            throw SocketException( SocketException::SEND_ERROR , remoteString() );
        }
        _bytesOut += ret;
        assert( ret <= len );
        len -= ret;
        data += ret;
    }
}
void Socket::_send( const vector< pair< char *, int > > &data, const char *context ) {
for( vector< pair< char *, int > >::const_iterator i = data.begin(); i != data.end(); ++i ) {
char * data = i->first;
int len = i->second;
send( data, len, context );
}
}
// sends all data or throws an exception
void Socket::send( const vector< pair< char *, int > > &data, const char *context ) {
#ifdef MONGO_SSL
if ( _ssl ) {
_send( data , context );
return;
}
#endif
#if defined(_WIN32)
// TODO use scatter/gather api
_send( data , context );
#else
vector< struct iovec > d( data.size() );
int i = 0;
for( vector< pair< char *, int > >::const_iterator j = data.begin(); j != data.end(); ++j ) {
if ( j->second > 0 ) {
d[ i ].iov_base = j->first;
d[ i ].iov_len = j->second;
++i;
_bytesOut += j->second;
}
}
struct msghdr meta;
memset( &meta, 0, sizeof( meta ) );
meta.msg_iov = &d[ 0 ];
meta.msg_iovlen = d.size();
while( meta.msg_iovlen > 0 ) {
int ret = ::sendmsg( _fd , &meta , portSendFlags );
if ( ret == -1 ) {
if ( errno != EAGAIN || _timeout == 0 ) {
log(_logLevel) << "Socket " << context << " send() " << errnoWithDescription() << ' ' << remoteString() << endl;
throw SocketException( SocketException::SEND_ERROR , remoteString() );
}
else {
log(_logLevel) << "Socket " << context << " send() remote timeout " << remoteString() << endl;
throw SocketException( SocketException::SEND_TIMEOUT , remoteString() );
}
}
else {
struct iovec *& i = meta.msg_iov;
while( ret > 0 ) {
if ( i->iov_len > unsigned( ret ) ) {
i->iov_len -= ret;
i->iov_base = (char*)(i->iov_base) + ret;
ret = 0;
}
else {
ret -= i->iov_len;
++i;
--(meta.msg_iovlen);
}
}
}
}
#endif
}
/** Reads exactly `len` bytes into `buf`, looping over partial reads.
 *  A single EINTR is retried (POSIX only).
 *  @throws SocketException CLOSED if the peer closed the connection, RECV_TIMEOUT if the read
 *          timed out while a timeout is configured, RECV_ERROR for any other failure. */
void Socket::recv( char * buf , int len ) {
    unsigned retries = 0;
    while( len > 0 ) {
        int ret = unsafe_recv( buf , len );
        if ( ret > 0 ) {
            if ( len <= 4 && ret != len )
                log(_logLevel) << "Socket recv() got " << ret << " bytes wanted len=" << len << endl;
            assert( ret <= len );
            len -= ret;
            buf += ret;
        }
        else if ( ret == 0 ) {
            log(3) << "Socket recv() conn closed? " << remoteString() << endl;
            throw SocketException( SocketException::CLOSED , remoteString() );
        }
        else { /* ret < 0 */
#if defined(_WIN32)
            // Winsock reports socket errors through WSAGetLastError(), not errno; reading
            // errno would make the WSAETIMEDOUT check below unreachable.
            int e = WSAGetLastError();
#else
            int e = errno;
#endif
#if defined(EINTR) && !defined(_WIN32)
            if( e == EINTR ) {
                if( ++retries == 1 ) {
                    log() << "EINTR retry" << endl;
                    continue;
                }
            }
#endif
#if defined(_WIN32)
            const bool timedOut = ( e == WSAETIMEDOUT );
#else
            const bool timedOut = ( e == EAGAIN || e == EWOULDBLOCK );
#endif
            if ( timedOut && _timeout > 0 ) {
                // this is a timeout
                log(_logLevel) << "Socket recv() timeout " << remoteString() <<endl;
                throw SocketException( SocketException::RECV_TIMEOUT, remoteString() );
            }
            log(_logLevel) << "Socket recv() " << errnoWithDescription(e) << " " << remoteString() <<endl;
            throw SocketException( SocketException::RECV_ERROR , remoteString() );
        }
    }
}
/** Single read attempt of up to `max` bytes. Returns the raw _recv() result (> 0 bytes read,
 *  0 on orderly close, < 0 on error) and adds only successfully read bytes to _bytesIn. */
int Socket::unsafe_recv( char *buf, int max ) {
    int x = _recv( buf , max );
    // A negative result is an error indicator, not a byte count; don't let it shrink _bytesIn.
    if ( x > 0 )
        _bytesIn += x;
    return x;
}
/** Raw read: SSL_read() when the socket is SSL-wrapped, otherwise ::recv() with
 *  portRecvFlags. Returns the underlying call's result unchanged. */
int Socket::_recv( char *buf, int max ) {
#ifdef MONGO_SSL
    if ( _ssl )
        return SSL_read( _ssl.get() , buf , max );
#endif
    return ::recv( _fd , buf , max , portRecvFlags );
}
/** Sets both the receive (SO_RCVTIMEO) and send (SO_SNDTIMEO) timeouts to `secs` seconds.
 *  Failures are only logged, and only when logLevel > 3 or in DEV builds
 *  (some platforms, e.g. Solaris, don't support these options). */
void Socket::setTimeout( double secs ) {
    bool report = logLevel > 3; // solaris doesn't provide these
    DEV report = true;
#if defined(_WIN32)
    // Winsock takes these options as a DWORD in milliseconds, not a struct timeval. Passing
    // a timeval would make tv_sec be read as milliseconds.
    DWORD timeout = (DWORD)( secs * 1000 );
    bool ok = setsockopt(_fd, SOL_SOCKET, SO_RCVTIMEO, (char *) &timeout, sizeof(timeout) ) == 0;
    if( report && !ok ) log() << "unable to set SO_RCVTIMEO" << endl;
    ok = setsockopt(_fd, SOL_SOCKET, SO_SNDTIMEO, (char *) &timeout, sizeof(timeout) ) == 0;
    if( report && !ok ) log() << "unable to set SO_SNDTIMEO" << endl;
#else
    struct timeval tv;
    tv.tv_sec = (int)secs;
    tv.tv_usec = (int)((long long)(secs*1000*1000) % (1000*1000));
    bool ok = setsockopt(_fd, SOL_SOCKET, SO_RCVTIMEO, (char *) &tv, sizeof(tv) ) == 0;
    if( report && !ok ) log() << "unable to set SO_RCVTIMEO" << endl;
    ok = setsockopt(_fd, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, sizeof(tv) ) == 0;
    if( report && !ok ) log() << "unable to set SO_SNDTIMEO" << endl;
#endif
}
#if defined(_WIN32)
/** Calls WSAStartup (Winsock 2.2) while this file's globals are constructed. If startup
 *  fails, the error is reported and the process exits with EXIT_NTSERVICE_ERROR. */
struct WinsockInit {
    WinsockInit() {
        WSADATA wsaData;
        const int rc = WSAStartup(MAKEWORD(2,2), &wsaData);
        if ( rc == 0 )
            return;
        out() << "ERROR: wsastartup failed " << errnoWithDescription() << endl;
        problem() << "ERROR: wsastartup failed " << errnoWithDescription() << endl;
        dbexit( EXIT_NTSERVICE_ERROR );
    }
} winsock_init;
#endif
} // namespace mongo