0
0
mirror of https://github.com/mongodb/mongo.git synced 2024-12-01 09:32:32 +01:00

SERVER-19420 implement connection hook API in NetworkInterfaceASIO

This commit is contained in:
Adam Midvidy 2015-08-14 11:26:02 -04:00
parent 4c61da5028
commit e6ddd3da54
13 changed files with 672 additions and 204 deletions

View File

@ -72,7 +72,6 @@ env.Library(target='network_test_env',
env.Library(
target='network_interface_asio',
source=[
'async_mock_stream_factory.cpp',
'async_secure_stream.cpp',
'async_secure_stream_factory.cpp',
'async_stream.cpp',
@ -96,6 +95,7 @@ env.Library(
env.CppUnitTest(
target='network_interface_asio_test',
source=[
'async_mock_stream_factory.cpp',
'network_interface_asio_test.cpp',
],
LIBDEPS=[

View File

@ -32,12 +32,18 @@
#include "mongo/executor/async_mock_stream_factory.h"
#include <exception>
#include <iterator>
#include <system_error>
#include "mongo/rpc/command_reply_builder.h"
#include "mongo/rpc/factory.h"
#include "mongo/rpc/legacy_reply_builder.h"
#include "mongo/rpc/request_interface.h"
#include "mongo/stdx/memory.h"
#include "mongo/util/assert_util.h"
#include "mongo/util/log.h"
#include "mongo/util/net/message.h"
namespace mongo {
namespace executor {
@ -86,9 +92,9 @@ void AsyncMockStreamFactory::MockStream::connect(asio::ip::tcp::resolver::iterat
ConnectHandler&& connectHandler) {
{
stdx::unique_lock<stdx::mutex> lk(_mutex);
log() << "connect() for: " << _target;
_block_inlock(&lk);
// Block before returning from connect.
_block_inlock(kBlockedBeforeConnect, &lk);
}
_io_service->post([connectHandler, endpoints] { connectHandler(std::error_code()); });
}
@ -96,15 +102,13 @@ void AsyncMockStreamFactory::MockStream::connect(asio::ip::tcp::resolver::iterat
void AsyncMockStreamFactory::MockStream::write(asio::const_buffer buf,
StreamHandler&& writeHandler) {
stdx::unique_lock<stdx::mutex> lk(_mutex);
log() << "write() for: " << _target;
auto begin = asio::buffer_cast<const uint8_t*>(buf);
auto size = asio::buffer_size(buf);
_writeQueue.push({begin, begin + size});
// Block after data is written.
_block_inlock(&lk);
_block_inlock(kBlockedAfterWrite, &lk);
lk.unlock();
_io_service->post([writeHandler, size] { writeHandler(std::error_code(), size); });
@ -113,10 +117,8 @@ void AsyncMockStreamFactory::MockStream::write(asio::const_buffer buf,
void AsyncMockStreamFactory::MockStream::read(asio::mutable_buffer buf,
StreamHandler&& readHandler) {
stdx::unique_lock<stdx::mutex> lk(_mutex);
log() << "read() for: " << _target;
// Block before data is read.
_block_inlock(&lk);
_block_inlock(kBlockedBeforeRead, &lk);
auto nextRead = std::move(_readQueue.front());
_readQueue.pop();
@ -138,26 +140,26 @@ void AsyncMockStreamFactory::MockStream::read(asio::mutable_buffer buf,
void AsyncMockStreamFactory::MockStream::pushRead(std::vector<uint8_t> toRead) {
stdx::unique_lock<stdx::mutex> lk(_mutex);
invariant(_blocked);
invariant(_state != kRunning);
_readQueue.emplace(std::move(toRead));
}
std::vector<uint8_t> AsyncMockStreamFactory::MockStream::popWrite() {
stdx::unique_lock<stdx::mutex> lk(_mutex);
invariant(_blocked);
invariant(_state != kRunning);
auto nextWrite = std::move(_writeQueue.front());
_writeQueue.pop();
return nextWrite;
}
void AsyncMockStreamFactory::MockStream::_block_inlock(stdx::unique_lock<stdx::mutex>* lk) {
log() << "blocking in stream for: " << _target;
invariant(!_blocked);
_blocked = true;
void AsyncMockStreamFactory::MockStream::_block_inlock(StreamState state,
stdx::unique_lock<stdx::mutex>* lk) {
invariant(_state == kRunning);
_state = state;
lk->unlock();
_cv.notify_one();
lk->lock();
_cv.wait(*lk, [this]() { return !_blocked; });
_cv.wait(*lk, [this]() { return _state == kRunning; });
}
void AsyncMockStreamFactory::MockStream::unblock() {
@ -166,18 +168,79 @@ void AsyncMockStreamFactory::MockStream::unblock() {
}
void AsyncMockStreamFactory::MockStream::_unblock_inlock(stdx::unique_lock<stdx::mutex>* lk) {
log() << "unblocking stream for: " << _target;
invariant(_blocked);
_blocked = false;
invariant(_state != kRunning);
_state = kRunning;
lk->unlock();
_cv.notify_one();
lk->lock();
}
void AsyncMockStreamFactory::MockStream::waitUntilBlocked() {
auto AsyncMockStreamFactory::MockStream::waitUntilBlocked() -> StreamState {
stdx::unique_lock<stdx::mutex> lk(_mutex);
log() << "waiting until stream for " << _target << " has blocked";
_cv.wait(lk, [this]() { return _blocked; });
_cv.wait(lk, [this]() { return _state != kRunning; });
return _state;
}
HostAndPort AsyncMockStreamFactory::MockStream::target() {
return _target;
}
void AsyncMockStreamFactory::MockStream::simulateServer(
rpc::Protocol proto,
const stdx::function<RemoteCommandResponse(RemoteCommandRequest)> replyFunc) {
std::exception_ptr ex;
uint32_t messageId = 0;
RemoteCommandResponse resp;
{
WriteEvent write{this};
std::vector<uint8_t> messageData = popWrite();
Message msg(messageData.data(), false);
auto parsedRequest = rpc::makeRequest(&msg);
ASSERT(parsedRequest->getProtocol() == proto);
RemoteCommandRequest rcr(target(), *parsedRequest);
messageId = msg.header().getId();
// So we can allow ASSERTs in replyFunc, we capture any exceptions, but rethrow
// them later to prevent deadlock
try {
resp = replyFunc(std::move(rcr));
} catch (...) {
ex = std::current_exception();
}
}
auto replyBuilder = rpc::makeReplyBuilder(proto);
replyBuilder->setMetadata(resp.metadata);
replyBuilder->setCommandReply(resp.data);
auto replyMsg = replyBuilder->done();
replyMsg->header().setResponseTo(messageId);
{
// The first read will be for the header.
ReadEvent read{this};
auto hdrBytes = reinterpret_cast<const uint8_t*>(replyMsg->header().view2ptr());
pushRead({hdrBytes, hdrBytes + sizeof(MSGHEADER::Value)});
}
{
// The second read will be for the message data.
ReadEvent read{this};
auto dataBytes = reinterpret_cast<const uint8_t*>(replyMsg->buf());
auto pastHeader = dataBytes;
std::advance(pastHeader, sizeof(MSGHEADER::Value));
pushRead({pastHeader, dataBytes + static_cast<std::size_t>(replyMsg->size())});
}
if (ex) {
// Rethrow ASSERTS after the NIA completes it's Write-Read-Read sequence.
std::rethrow_exception(ex);
}
}
} // namespace executor

View File

@ -36,8 +36,13 @@
#include "mongo/executor/async_stream_factory_interface.h"
#include "mongo/executor/async_stream_interface.h"
#include "mongo/stdx/mutex.h"
#include "mongo/executor/remote_command_request.h"
#include "mongo/executor/remote_command_response.h"
#include "mongo/rpc/protocol.h"
#include "mongo/stdx/condition_variable.h"
#include "mongo/stdx/functional.h"
#include "mongo/stdx/mutex.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/net/hostandport.h"
namespace mongo {
@ -57,6 +62,15 @@ public:
MockStream(asio::io_service* io_service,
AsyncMockStreamFactory* factory,
const HostAndPort& target);
// Use unscoped enum so we can specialize on it
enum StreamState {
kRunning,
kBlockedBeforeConnect,
kBlockedBeforeRead,
kBlockedAfterWrite,
};
~MockStream();
void connect(asio::ip::tcp::resolver::iterator endpoints,
@ -64,27 +78,32 @@ public:
void write(asio::const_buffer buf, StreamHandler&& writeHandler) override;
void read(asio::mutable_buffer buf, StreamHandler&& readHandler) override;
void waitUntilBlocked();
HostAndPort target();
StreamState waitUntilBlocked();
std::vector<uint8_t> popWrite();
void pushRead(std::vector<uint8_t> toRead);
void unblock();
void simulateServer(
rpc::Protocol proto,
const stdx::function<RemoteCommandResponse(RemoteCommandRequest)> replyFunc);
private:
void _unblock_inlock(stdx::unique_lock<stdx::mutex>* lk);
void _block_inlock(stdx::unique_lock<stdx::mutex>* lk);
void _block_inlock(StreamState state, stdx::unique_lock<stdx::mutex>* lk);
asio::io_service* _io_service;
AsyncMockStreamFactory* _factory;
HostAndPort _target;
stdx::mutex _mutex;
stdx::condition_variable _cv;
bool _blocked{false};
StreamState _state{kRunning};
std::queue<std::vector<uint8_t>> _readQueue;
std::queue<std::vector<uint8_t>> _writeQueue;
@ -102,5 +121,33 @@ private:
std::unordered_map<HostAndPort, MockStream*> _streams;
};
template <int EventType>
class StreamEvent {
public:
StreamEvent(AsyncMockStreamFactory::MockStream* stream) : _stream(stream) {
ASSERT(stream->waitUntilBlocked() == EventType);
}
void skip() {
_stream->unblock();
skipped = true;
}
~StreamEvent() {
if (!skipped) {
skip();
}
}
private:
bool skipped = false;
AsyncMockStreamFactory::MockStream* _stream = nullptr;
};
using ReadEvent = StreamEvent<AsyncMockStreamFactory::MockStream::StreamState::kBlockedBeforeRead>;
using WriteEvent = StreamEvent<AsyncMockStreamFactory::MockStream::StreamState::kBlockedAfterWrite>;
using ConnectEvent =
StreamEvent<AsyncMockStreamFactory::MockStream::StreamState::kBlockedBeforeConnect>;
} // namespace executor
} // namespace mongo

View File

@ -48,7 +48,13 @@ namespace executor {
NetworkInterfaceASIO::NetworkInterfaceASIO(
std::unique_ptr<AsyncStreamFactoryInterface> streamFactory)
: NetworkInterfaceASIO(std::move(streamFactory), nullptr) {}
NetworkInterfaceASIO::NetworkInterfaceASIO(
std::unique_ptr<AsyncStreamFactoryInterface> streamFactory,
std::unique_ptr<NetworkConnectionHook> networkConnectionHook)
: _io_service(),
_hook(std::move(networkConnectionHook)),
_resolver(_io_service),
_state(State::kReady),
_streamFactory(std::move(streamFactory)),

View File

@ -38,6 +38,7 @@
#include "mongo/base/status.h"
#include "mongo/base/system_error.h"
#include "mongo/executor/network_connection_hook.h"
#include "mongo/executor/network_interface.h"
#include "mongo/executor/remote_command_request.h"
#include "mongo/executor/remote_command_response.h"
@ -54,7 +55,6 @@ namespace executor {
class AsyncStreamFactoryInterface;
class AsyncStreamInterface;
class NetworkConnectionHook;
/**
* Implementation of the replication system's network interface using Christopher
@ -62,6 +62,8 @@ class NetworkConnectionHook;
*/
class NetworkInterfaceASIO final : public NetworkInterface {
public:
NetworkInterfaceASIO(std::unique_ptr<AsyncStreamFactoryInterface> streamFactory,
std::unique_ptr<NetworkConnectionHook> networkConnectionHook);
NetworkInterfaceASIO(std::unique_ptr<AsyncStreamFactoryInterface> streamFactory);
std::string getDiagnosticString() override;
std::string getHostName() override;
@ -239,6 +241,7 @@ private:
void _setupSocket(AsyncOp* op, asio::ip::tcp::resolver::iterator endpoints);
void _runIsMaster(AsyncOp* op);
void _runConnectionHook(AsyncOp* op);
void _authenticate(AsyncOp* op);
// Communication state machine
@ -254,6 +257,8 @@ private:
asio::io_service _io_service;
stdx::thread _serviceRunner;
std::unique_ptr<NetworkConnectionHook> _hook;
asio::ip::tcp::resolver _resolver;
std::atomic<State> _state;

View File

@ -62,33 +62,40 @@ void NetworkInterfaceASIO::_runIsMaster(AsyncOp* op) {
// Callback to parse protocol information out of received ismaster response
auto parseIsMaster = [this, op]() {
try {
auto commandReply = rpc::makeReply(&(op->command().toRecv()));
BSONObj isMasterReply = commandReply->getCommandReply();
auto protocolSet = rpc::parseProtocolSetFromIsMasterReply(isMasterReply);
if (!protocolSet.isOK())
return _completeOperation(op, protocolSet.getStatus());
op->connection().setServerProtocols(protocolSet.getValue());
// Set the operation protocol
auto negotiatedProtocol = rpc::negotiate(op->connection().serverProtocols(),
op->connection().clientProtocols());
if (!negotiatedProtocol.isOK()) {
return _completeOperation(op, negotiatedProtocol.getStatus());
}
op->setOperationProtocol(negotiatedProtocol.getValue());
// Advance the state machine
return _authenticate(op);
} catch (...) {
// makeReply will throw if the reply was invalid.
return _completeOperation(op, exceptionToStatus());
auto swCommandReply = op->command().response(rpc::Protocol::kOpQuery, now());
if (!swCommandReply.isOK()) {
return _completeOperation(op, swCommandReply.getStatus());
}
auto commandReply = std::move(swCommandReply.getValue());
if (_hook) {
// Run the validation hook.
auto validHost = _hook->validateHost(op->request().target, commandReply);
if (!validHost.isOK()) {
return _completeOperation(op, validHost);
}
}
auto protocolSet = rpc::parseProtocolSetFromIsMasterReply(commandReply.data);
if (!protocolSet.isOK())
return _completeOperation(op, protocolSet.getStatus());
op->connection().setServerProtocols(protocolSet.getValue());
// Set the operation protocol
auto negotiatedProtocol =
rpc::negotiate(op->connection().serverProtocols(), op->connection().clientProtocols());
if (!negotiatedProtocol.isOK()) {
return _completeOperation(op, negotiatedProtocol.getStatus());
}
op->setOperationProtocol(negotiatedProtocol.getValue());
return _authenticate(op);
};
_asyncRunCommand(&cmd,
@ -105,7 +112,7 @@ void NetworkInterfaceASIO::_authenticate(AsyncOp* op) {
// This check is sufficient to see if auth is enabled on the system,
// and avoids creating dependencies on deeper, less accessible auth code.
if (!isInternalAuthSet()) {
return asio::post(_io_service, [this, op]() { _beginCommunication(op); });
return _runConnectionHook(op);
}
// We will only have a valid clientName if SSL is enabled.
@ -136,7 +143,7 @@ void NetworkInterfaceASIO::_authenticate(AsyncOp* op) {
auto authHook = [this, op](auth::AuthResponse response) {
if (!response.isOK())
return _completeOperation(op, response);
return _beginCommunication(op);
return _runConnectionHook(op);
};
auto params = getInternalUserAuthParamsWithFallback();

View File

@ -41,6 +41,7 @@
#include "mongo/rpc/factory.h"
#include "mongo/rpc/protocol.h"
#include "mongo/rpc/reply_interface.h"
#include "mongo/rpc/request_builder_interface.h"
#include "mongo/stdx/memory.h"
#include "mongo/util/assert_util.h"
#include "mongo/util/log.h"
@ -271,5 +272,48 @@ void NetworkInterfaceASIO::_asyncRunCommand(AsyncCommand* cmd, NetworkOpHandler
asyncSendMessage(cmd->conn().stream(), &cmd->toSend(), std::move(sendMessageCallback));
}
void NetworkInterfaceASIO::_runConnectionHook(AsyncOp* op) {
if (!_hook) {
return _beginCommunication(op);
}
auto swOptionalRequest = _hook->makeRequest(op->request().target);
if (!swOptionalRequest.isOK()) {
return _completeOperation(op, swOptionalRequest.getStatus());
}
auto optionalRequest = std::move(swOptionalRequest.getValue());
if (optionalRequest == boost::none) {
return _beginCommunication(op);
}
auto& cmd = op->beginCommand(*optionalRequest, op->operationProtocol(), now());
auto finishHook = [this, op]() {
auto response = op->command().response(op->operationProtocol(), now());
if (!response.isOK()) {
return _completeOperation(op, response.getStatus());
}
auto handleStatus =
_hook->handleReply(op->request().target, std::move(response.getValue()));
if (!handleStatus.isOK()) {
return _completeOperation(op, handleStatus);
}
return _beginCommunication(op);
};
return _asyncRunCommand(&cmd,
[this, op, finishHook](std::error_code ec, std::size_t bytes) {
_validateAndRun(op, ec, finishHook);
});
}
} // namespace executor
} // namespace mongo

View File

@ -30,15 +30,15 @@
#include "mongo/platform/basic.h"
#include <boost/optional.hpp>
#include "mongo/base/status_with.h"
#include "mongo/db/jsobj.h"
#include "mongo/db/wire_version.h"
#include "mongo/executor/async_mock_stream_factory.h"
#include "mongo/executor/network_interface_asio.h"
#include "mongo/rpc/command_reply_builder.h"
#include "mongo/rpc/factory.h"
#include "mongo/executor/test_network_connection_hook.h"
#include "mongo/rpc/legacy_reply_builder.h"
#include "mongo/rpc/protocol.h"
#include "mongo/rpc/request_interface.h"
#include "mongo/stdx/future.h"
#include "mongo/stdx/memory.h"
#include "mongo/unittest/unittest.h"
@ -49,6 +49,8 @@ namespace mongo {
namespace executor {
namespace {
HostAndPort testHost{"localhost", 20000};
class NetworkInterfaceASIOTest : public mongo::unittest::Test {
public:
void setUp() override {
@ -73,7 +75,11 @@ public:
return *_streamFactory;
}
private:
void simulateServerReply(AsyncMockStreamFactory::MockStream* stream,
rpc::Protocol proto,
const stdx::function<RemoteCommandResponse(RemoteCommandRequest)>) {}
protected:
AsyncMockStreamFactory* _streamFactory;
std::unique_ptr<NetworkInterfaceASIO> _net;
};
@ -104,118 +110,39 @@ TEST_F(NetworkInterfaceASIOTest, StartCommand) {
auto stream = streamFactory().blockUntilStreamExists(testHost);
// Allow stream to connect.
{
stream->waitUntilBlocked();
auto guard = MakeGuard([&] { stream->unblock(); });
}
ConnectEvent{stream}.skip();
log() << "connected";
// simulate isMaster reply.
stream->simulateServer(
rpc::Protocol::kOpQuery,
[](RemoteCommandRequest request) -> RemoteCommandResponse {
ASSERT_EQ(std::string{request.cmdObj.firstElementFieldName()}, "isMaster");
ASSERT_EQ(request.dbname, "admin");
uint32_t isMasterMsgId = 0;
{
stream->waitUntilBlocked();
auto guard = MakeGuard([&] { stream->unblock(); });
log() << "NIA blocked after writing isMaster request";
// Check that an isMaster has been run on the stream
std::vector<uint8_t> messageData = stream->popWrite();
Message msg(messageData.data(), false);
auto request = rpc::makeRequest(&msg);
ASSERT_EQ(request->getCommandName(), "isMaster");
ASSERT_EQ(request->getDatabase(), "admin");
isMasterMsgId = msg.header().getId();
// Check that we used OP_QUERY.
ASSERT(request->getProtocol() == rpc::Protocol::kOpQuery);
}
rpc::LegacyReplyBuilder replyBuilder;
replyBuilder.setMetadata(BSONObj());
replyBuilder.setCommandReply(BSON("minWireVersion" << mongo::minWireVersion << "maxWireVersion"
<< mongo::maxWireVersion));
auto replyMsg = replyBuilder.done();
replyMsg->header().setResponseTo(isMasterMsgId);
{
stream->waitUntilBlocked();
auto guard = MakeGuard([&] { stream->unblock(); });
log() << "NIA blocked before reading isMaster reply header";
// write out the full message now, even though another read() call will read the rest.
auto hdrBytes = reinterpret_cast<const uint8_t*>(replyMsg->header().view2ptr());
stream->pushRead({hdrBytes, hdrBytes + sizeof(MSGHEADER::Value)});
auto dataBytes = reinterpret_cast<const uint8_t*>(replyMsg->buf());
auto pastHeader = dataBytes;
std::advance(pastHeader, sizeof(MSGHEADER::Value)); // skip the header this time
stream->pushRead({pastHeader, dataBytes + static_cast<std::size_t>(replyMsg->size())});
}
{
stream->waitUntilBlocked();
auto guard = MakeGuard([&] { stream->unblock(); });
log() << "NIA blocked before reading isMaster reply data";
}
RemoteCommandResponse response;
response.data = BSON("minWireVersion" << mongo::minWireVersion << "maxWireVersion"
<< mongo::maxWireVersion);
return response;
});
auto expectedMetadata = BSON("meep"
<< "beep");
auto expectedCommandReply = BSON("boop"
<< "bop"
<< "ok" << 1.0);
auto expectedMetadata = BSON("meep"
<< "beep");
{
stream->waitUntilBlocked();
auto guard = MakeGuard([&] { stream->unblock(); });
// simulate user command
stream->simulateServer(rpc::Protocol::kOpCommandV1,
[&](RemoteCommandRequest request) -> RemoteCommandResponse {
ASSERT_EQ(std::string{request.cmdObj.firstElementFieldName()},
"foo");
ASSERT_EQ(request.dbname, "testDB");
log() << "blocked after write(), reading user command request";
std::vector<uint8_t> messageData{stream->popWrite()};
Message msg(messageData.data(), false);
auto request = rpc::makeRequest(&msg);
// the command we requested should be running.
ASSERT_EQ(request->getCommandName(), "foo");
ASSERT_EQ(request->getDatabase(), "testDB");
// we should be using op command given our previous isMaster reply.
ASSERT(request->getProtocol() == rpc::Protocol::kOpCommandV1);
rpc::CommandReplyBuilder replyBuilder;
replyBuilder.setMetadata(expectedMetadata).setCommandReply(expectedCommandReply);
auto replyMsg = replyBuilder.done();
replyMsg->header().setResponseTo(msg.header().getId());
// write out the full message now, even though another read() call will read the rest.
auto hdrBytes = reinterpret_cast<const uint8_t*>(replyMsg->header().view2ptr());
stream->pushRead({hdrBytes, hdrBytes + sizeof(MSGHEADER::Value)});
auto dataBytes = reinterpret_cast<const uint8_t*>(replyMsg->buf());
auto pastHeader = dataBytes;
std::advance(pastHeader, sizeof(MSGHEADER::Value)); // skip the header this time
stream->pushRead({pastHeader, dataBytes + static_cast<std::size_t>(replyMsg->size())});
}
{
stream->waitUntilBlocked();
auto guard = MakeGuard([&] { stream->unblock(); });
}
{
stream->waitUntilBlocked();
auto guard = MakeGuard([&] { stream->unblock(); });
}
RemoteCommandResponse response;
response.data = expectedCommandReply;
response.metadata = expectedMetadata;
return response;
});
auto res = fut.get();
@ -224,6 +151,300 @@ TEST_F(NetworkInterfaceASIOTest, StartCommand) {
ASSERT_EQ(res.metadata, expectedMetadata);
}
class NetworkInterfaceASIOConnectionHookTest : public NetworkInterfaceASIOTest {
public:
void setUp() override {}
void start(std::unique_ptr<NetworkConnectionHook> hook) {
auto factory = stdx::make_unique<AsyncMockStreamFactory>();
// keep unowned pointer, but pass ownership to NIA
_streamFactory = factory.get();
_net = stdx::make_unique<NetworkInterfaceASIO>(std::move(factory), std::move(hook));
_net->startup();
}
};
template <typename T>
void assertThrowsStatus(stdx::future<T>&& fut, const Status& s) {
ASSERT([&] {
try {
std::forward<stdx::future<T>>(fut).get();
return false;
} catch (const DBException& ex) {
return ex.toStatus() == s;
}
}());
}
TEST_F(NetworkInterfaceASIOConnectionHookTest, ValidateHostInvalid) {
bool validateCalled = false;
bool hostCorrect = false;
bool isMasterReplyCorrect = false;
bool makeRequestCalled = false;
bool handleReplyCalled = false;
auto validationFailedStatus = Status(ErrorCodes::AlreadyInitialized, "blahhhhh");
start(makeTestHook(
[&](const HostAndPort& remoteHost, const RemoteCommandResponse& isMasterReply) {
validateCalled = true;
hostCorrect = (remoteHost == testHost);
isMasterReplyCorrect = (isMasterReply.data["TESTKEY"].str() == "TESTVALUE");
return validationFailedStatus;
},
[&](const HostAndPort& remoteHost) -> StatusWith<boost::optional<RemoteCommandRequest>> {
makeRequestCalled = true;
return {boost::none};
},
[&](const HostAndPort& remoteHost, RemoteCommandResponse&& response) {
handleReplyCalled = true;
return Status::OK();
}));
stdx::promise<RemoteCommandResponse> done;
auto doneFuture = done.get_future();
net().startCommand({},
{testHost,
"blah",
BSON("foo"
<< "bar")},
[&](StatusWith<RemoteCommandResponse> result) {
try {
done.set_value(uassertStatusOK(result));
} catch (...) {
done.set_exception(std::current_exception());
}
});
auto stream = streamFactory().blockUntilStreamExists(testHost);
ConnectEvent{stream}.skip();
// simulate isMaster reply.
stream->simulateServer(rpc::Protocol::kOpQuery,
[](RemoteCommandRequest request) -> RemoteCommandResponse {
RemoteCommandResponse response;
response.data = BSON("minWireVersion"
<< mongo::minWireVersion << "maxWireVersion"
<< mongo::maxWireVersion << "TESTKEY"
<< "TESTVALUE");
return response;
});
// we should stop here.
assertThrowsStatus(std::move(doneFuture), validationFailedStatus);
ASSERT(validateCalled);
ASSERT(hostCorrect);
ASSERT(isMasterReplyCorrect);
ASSERT(!makeRequestCalled);
ASSERT(!handleReplyCalled);
}
TEST_F(NetworkInterfaceASIOConnectionHookTest, MakeRequestReturnsError) {
bool makeRequestCalled = false;
bool handleReplyCalled = false;
Status makeRequestError{ErrorCodes::DBPathInUse, "bloooh"};
start(makeTestHook(
[&](const HostAndPort& remoteHost, const RemoteCommandResponse& isMasterReply)
-> Status { return Status::OK(); },
[&](const HostAndPort& remoteHost) -> StatusWith<boost::optional<RemoteCommandRequest>> {
makeRequestCalled = true;
return makeRequestError;
},
[&](const HostAndPort& remoteHost, RemoteCommandResponse&& response) {
handleReplyCalled = true;
return Status::OK();
}));
stdx::promise<RemoteCommandResponse> done;
auto doneFuture = done.get_future();
net().startCommand({},
{testHost,
"blah",
BSON("foo"
<< "bar")},
[&](StatusWith<RemoteCommandResponse> result) {
try {
done.set_value(uassertStatusOK(result));
} catch (...) {
done.set_exception(std::current_exception());
}
});
auto stream = streamFactory().blockUntilStreamExists(testHost);
ConnectEvent{stream}.skip();
// simulate isMaster reply.
stream->simulateServer(rpc::Protocol::kOpQuery,
[](RemoteCommandRequest request) -> RemoteCommandResponse {
RemoteCommandResponse response;
response.data = BSON("minWireVersion" << mongo::minWireVersion
<< "maxWireVersion"
<< mongo::maxWireVersion);
return response;
});
// We should stop here.
assertThrowsStatus(std::move(doneFuture), makeRequestError);
ASSERT(makeRequestCalled);
ASSERT(!handleReplyCalled);
}
TEST_F(NetworkInterfaceASIOConnectionHookTest, MakeRequestReturnsNone) {
bool makeRequestCalled = false;
bool handleReplyCalled = false;
start(makeTestHook(
[&](const HostAndPort& remoteHost, const RemoteCommandResponse& isMasterReply)
-> Status { return Status::OK(); },
[&](const HostAndPort& remoteHost) -> StatusWith<boost::optional<RemoteCommandRequest>> {
makeRequestCalled = true;
return {boost::none};
},
[&](const HostAndPort& remoteHost, RemoteCommandResponse&& response) {
handleReplyCalled = true;
return Status::OK();
}));
stdx::promise<RemoteCommandResponse> done;
auto doneFuture = done.get_future();
auto commandRequest = BSON("foo"
<< "bar");
net().startCommand({},
{testHost, "blah", commandRequest},
[&](StatusWith<RemoteCommandResponse> result) {
try {
done.set_value(uassertStatusOK(result));
} catch (...) {
done.set_exception(std::current_exception());
}
});
auto stream = streamFactory().blockUntilStreamExists(testHost);
ConnectEvent{stream}.skip();
// simulate isMaster reply.
stream->simulateServer(rpc::Protocol::kOpQuery,
[](RemoteCommandRequest request) -> RemoteCommandResponse {
RemoteCommandResponse response;
response.data = BSON("minWireVersion" << mongo::minWireVersion
<< "maxWireVersion"
<< mongo::maxWireVersion);
return response;
});
auto commandReply = BSON("foo"
<< "boo"
<< "ok" << 1.0);
auto metadata = BSON("aaa"
<< "bbb");
// Simulate user command.
stream->simulateServer(rpc::Protocol::kOpCommandV1,
[&](RemoteCommandRequest request) -> RemoteCommandResponse {
ASSERT_EQ(commandRequest, request.cmdObj);
RemoteCommandResponse response;
response.data = commandReply;
response.metadata = metadata;
return response;
});
// We should get back the reply now.
auto reply = doneFuture.get();
ASSERT_EQ(reply.data, commandReply);
ASSERT_EQ(reply.metadata, metadata);
}
TEST_F(NetworkInterfaceASIOConnectionHookTest, HandleReplyReturnsError) {
bool makeRequestCalled = false;
bool handleReplyCalled = false;
bool handleReplyArgumentCorrect = false;
BSONObj hookCommandRequest = BSON("1ddd"
<< "fff");
BSONObj hookRequestMetadata = BSON("wdwd" << 1212);
BSONObj hookCommandReply = BSON("blah"
<< "blah"
<< "ok" << 1.0);
BSONObj hookReplyMetadata = BSON("1111" << 2222);
Status handleReplyError{ErrorCodes::AuthSchemaIncompatible, "daowdjkpowkdjpow"};
start(makeTestHook(
[&](const HostAndPort& remoteHost, const RemoteCommandResponse& isMasterReply)
-> Status { return Status::OK(); },
[&](const HostAndPort& remoteHost) -> StatusWith<boost::optional<RemoteCommandRequest>> {
makeRequestCalled = true;
return {boost::make_optional<RemoteCommandRequest>(
{testHost, "foo", hookCommandRequest, hookRequestMetadata})};
},
[&](const HostAndPort& remoteHost, RemoteCommandResponse&& response) {
handleReplyCalled = true;
handleReplyArgumentCorrect =
(response.data == hookCommandReply) && (response.metadata == hookReplyMetadata);
return handleReplyError;
}));
stdx::promise<RemoteCommandResponse> done;
auto doneFuture = done.get_future();
auto commandRequest = BSON("foo"
<< "bar");
net().startCommand({},
{testHost, "blah", commandRequest},
[&](StatusWith<RemoteCommandResponse> result) {
try {
done.set_value(uassertStatusOK(result));
} catch (...) {
done.set_exception(std::current_exception());
}
});
auto stream = streamFactory().blockUntilStreamExists(testHost);
ConnectEvent{stream}.skip();
// simulate isMaster reply.
stream->simulateServer(rpc::Protocol::kOpQuery,
[](RemoteCommandRequest request) -> RemoteCommandResponse {
RemoteCommandResponse response;
response.data = BSON("minWireVersion" << mongo::minWireVersion
<< "maxWireVersion"
<< mongo::maxWireVersion);
return response;
});
// Simulate hook reply
stream->simulateServer(rpc::Protocol::kOpCommandV1,
[&](RemoteCommandRequest request) -> RemoteCommandResponse {
ASSERT_EQ(request.cmdObj, hookCommandRequest);
ASSERT_EQ(request.metadata, hookRequestMetadata);
RemoteCommandResponse response;
response.data = hookCommandReply;
response.metadata = hookReplyMetadata;
return response;
});
assertThrowsStatus(std::move(doneFuture), handleReplyError);
ASSERT(makeRequestCalled);
ASSERT(handleReplyCalled);
ASSERT(handleReplyArgumentCorrect);
}
TEST_F(NetworkInterfaceASIOTest, setAlarm) {
stdx::promise<bool> nearFuture;
stdx::future<bool> executed = nearFuture.get_future();

View File

@ -33,9 +33,10 @@
#include <utility>
#include "mongo/base/status.h"
#include "mongo/executor/network_interface.h"
#include "mongo/executor/network_connection_hook.h"
#include "mongo/executor/network_interface.h"
#include "mongo/executor/network_interface_mock.h"
#include "mongo/executor/test_network_connection_hook.h"
#include "mongo/executor/thread_pool_mock.h"
#include "mongo/stdx/memory.h"
#include "mongo/unittest/unittest.h"
@ -44,44 +45,6 @@ namespace mongo {
namespace executor {
namespace {
template <typename ValidateFunc, typename RequestFunc, typename ReplyFunc>
class TestConnectionHook final : public NetworkConnectionHook {
public:
TestConnectionHook(ValidateFunc&& validateFunc,
RequestFunc&& requestFunc,
ReplyFunc&& replyFunc)
: _validateFunc(std::forward<ValidateFunc>(validateFunc)),
_requestFunc(std::forward<RequestFunc>(requestFunc)),
_replyFunc(std::forward<ReplyFunc>(replyFunc)) {}
Status validateHost(const HostAndPort& remoteHost,
const RemoteCommandResponse& isMasterReply) override {
return _validateFunc(remoteHost, isMasterReply);
}
StatusWith<boost::optional<RemoteCommandRequest>> makeRequest(const HostAndPort& remoteHost) {
return _requestFunc(remoteHost);
}
Status handleReply(const HostAndPort& remoteHost, RemoteCommandResponse&& response) {
return _replyFunc(remoteHost, std::move(response));
}
private:
ValidateFunc _validateFunc;
RequestFunc _requestFunc;
ReplyFunc _replyFunc;
};
template <typename Val, typename Req, typename Rep>
static std::unique_ptr<TestConnectionHook<Val, Req, Rep>> makeTestHook(Val&& validateFunc,
Req&& requestFunc,
Rep&& replyFunc) {
return stdx::make_unique<TestConnectionHook<Val, Req, Rep>>(std::forward<Val>(validateFunc),
std::forward<Req>(requestFunc),
std::forward<Rep>(replyFunc));
}
class NetworkInterfaceMockTest : public mongo::unittest::Test {
public:
NetworkInterfaceMockTest() : _net{}, _executor(&_net, 1) {}

View File

@ -31,8 +31,9 @@
#include <string>
#include "mongo/db/jsobj.h"
#include "mongo/util/net/hostandport.h"
#include "mongo/rpc/metadata.h"
#include "mongo/rpc/request_interface.h"
#include "mongo/util/net/hostandport.h"
#include "mongo/util/time_support.h"
namespace mongo {
@ -72,6 +73,15 @@ struct RemoteCommandRequest {
: RemoteCommandRequest(
theTarget, theDbName, theCmdObj, rpc::makeEmptyMetadata(), timeoutMillis) {}
RemoteCommandRequest(const HostAndPort& theTarget,
const rpc::RequestInterface& request,
const Milliseconds timeoutMillis = kNoTimeout)
: RemoteCommandRequest(theTarget,
request.getDatabase().toString(),
request.getCommandArgs(),
request.getMetadata(),
timeoutMillis) {}
std::string toString() const;
HostAndPort target;

View File

@ -30,11 +30,19 @@
#include "mongo/executor/remote_command_response.h"
#include "mongo/rpc/reply_interface.h"
#include "mongo/util/mongoutils/str.h"
namespace mongo {
namespace executor {
// TODO(amidvidy): we currently discard output docs when we use this constructor. We should
// have RCR hold those too, but we need more machinery before that is possible.
RemoteCommandResponse::RemoteCommandResponse(const rpc::ReplyInterface& rpcReply,
Milliseconds millis)
: RemoteCommandResponse(rpcReply.getCommandReply(), rpcReply.getMetadata(), std::move(millis)) {
}
std::string RemoteCommandResponse::toString() const {
return str::stream() << "RemoteResponse -- "
<< " cmd:" << data.toString();

View File

@ -34,6 +34,11 @@
#include "mongo/util/time_support.h"
namespace mongo {
namespace rpc {
class ReplyInterface;
} // namespace rpc
namespace executor {
@ -46,6 +51,8 @@ struct RemoteCommandResponse {
RemoteCommandResponse(BSONObj dataObj, BSONObj metadataObj, Milliseconds millis)
: data(std::move(dataObj)), metadata(std::move(metadataObj)), elapsedMillis(millis) {}
RemoteCommandResponse(const rpc::ReplyInterface& rpcReply, Milliseconds millis);
std::string toString() const;
BSONObj data;

View File

@ -0,0 +1,87 @@
/**
* Copyright (C) 2015 MongoDB Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, version 3,
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* As a special exception, the copyright holders give permission to link the
* code of portions of this program with the OpenSSL library under certain
* conditions as described in each individual source file and distribute
* linked combinations including the program with the OpenSSL library. You
* must comply with the GNU Affero General Public License in all respects for
* all of the code used other than as permitted herein. If you modify file(s)
* with this exception, you may extend this exception to your version of the
* file(s), but you are not obligated to do so. If you do not wish to do so,
* delete this exception statement from your version. If you delete this
* exception statement from all source files in the program, then also delete
* it in the license file.
*/
#include <boost/optional.hpp>
#include <memory>
#include "mongo/base/status_with.h"
#include "mongo/executor/network_connection_hook.h"
#include "mongo/stdx/memory.h"
namespace mongo {
namespace executor {
/**
* A utility for creating one-off NetworkConnectionHook instances from inline lambdas. This is
* only to be used in testing code, not in production.
*/
template <typename ValidateFunc, typename RequestFunc, typename ReplyFunc>
class TestConnectionHook final : public NetworkConnectionHook {
public:
TestConnectionHook(ValidateFunc&& validateFunc,
RequestFunc&& requestFunc,
ReplyFunc&& replyFunc)
: _validateFunc(std::forward<ValidateFunc>(validateFunc)),
_requestFunc(std::forward<RequestFunc>(requestFunc)),
_replyFunc(std::forward<ReplyFunc>(replyFunc)) {}
Status validateHost(const HostAndPort& remoteHost,
const RemoteCommandResponse& isMasterReply) override {
return _validateFunc(remoteHost, isMasterReply);
}
StatusWith<boost::optional<RemoteCommandRequest>> makeRequest(const HostAndPort& remoteHost) {
return _requestFunc(remoteHost);
}
Status handleReply(const HostAndPort& remoteHost, RemoteCommandResponse&& response) {
return _replyFunc(remoteHost, std::move(response));
}
private:
ValidateFunc _validateFunc;
RequestFunc _requestFunc;
ReplyFunc _replyFunc;
};
/**
* Factory function for TestConnectionHook instances. Needed for template type deduction, so that
* one can instantiate a TestConnectionHook instance without uttering the unutterable (types).
*/
template <typename Val, typename Req, typename Rep>
std::unique_ptr<TestConnectionHook<Val, Req, Rep>> makeTestHook(Val&& validateFunc,
Req&& requestFunc,
Rep&& replyFunc) {
return stdx::make_unique<TestConnectionHook<Val, Req, Rep>>(std::forward<Val>(validateFunc),
std::forward<Req>(requestFunc),
std::forward<Rep>(replyFunc));
}
} // namespace executor
} // namespace mongo