From d0ed8f9874f6693cdcdac0f5e2532705f607d566 Mon Sep 17 00:00:00 2001 From: Tim Felgentreff <timfelgentreff@gmail.com> Date: Sun, 13 Dec 2020 11:41:28 +0100 Subject: [PATCH] fix various conditions in which the online service message buffer would become corrupted --- src/network/online_service.cpp | 72 ++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 12 deletions(-) diff --git a/src/network/online_service.cpp b/src/network/online_service.cpp index 03547860d..6a797ae9b 100644 --- a/src/network/online_service.cpp +++ b/src/network/online_service.cpp @@ -97,8 +97,8 @@ public: ~BNCSInputStream() {}; std::string readString() { - if (received_bytes == 0) { - return NULL; + if (received_bytes - pos <= 0) { + return ""; } std::stringstream strstr; int i = pos; @@ -124,6 +124,19 @@ public: return stringlist; }; + std::vector<std::string> readStringlist(int cnt) { + std::vector<std::string> stringlist; + for (; cnt >= 0; cnt--) { + std::string nxt = readString(); + if (nxt.empty()) { + break; + } else { + stringlist.push_back(nxt); + } + } + return stringlist; + }; + uint8_t read8() { uint8_t byte = buffer[pos]; consumeData(1); @@ -188,12 +201,24 @@ public: // (UINT8) Message ID // (UINT16) Message length, including this header // (VOID) Message data - received_bytes += this->sock->Recv(buffer + received_bytes, 4 - received_bytes); - if (received_bytes != 4) { + if (received_bytes < 4) { + // in case of retry, we may already have the first 4 bytes + received_bytes += this->sock->Recv(buffer + received_bytes, 4 - received_bytes); + } + if (received_bytes < 4) { + // didn't get the complete header yet + return -1; + } + assert(pos == 0); + uint8_t headerbyte = read8(); + if (headerbyte != 0xff) { + // Likely a bug on our side. We just skip this byte. + debugDump(); + memmove(buffer, buffer + 1, received_bytes - 1); + received_bytes -= 1; + pos = 0; return -1; } - assert(received_bytes == 4); - assert(read8() == 0xff); uint8_t msgId = read8(); uint16_t len = read16(); // we still need to have len in total for this message, so if we have @@ -201,9 +226,9 @@ public: // first 4 bytes that we already consumed, we'll have enough assert(pos == 4); long needed = len - received_bytes; - if (needed != 0) { - assert(needed > 0); - if (len >= bufsize) { + if (needed > 0) { + if (needed >= bufsize + received_bytes) { + // we never shrink the buffer again while we're online buffer = (char*)realloc(buffer, sizeof(char) * len + 1); bufsize = len + 1; } @@ -218,13 +243,27 @@ public: return msgId; }; + void debugDump() { + if (EnableDebugPrint) { + std::cout << "Input stream state: pos(" << pos << "), received_bytes(" << received_bytes << ")" << std::endl; + dump((uint8_t*)buffer, received_bytes); + } + } + void finishMessage() { - received_bytes = 0; + assert(pos <= received_bytes); + received_bytes = received_bytes - pos; + if (received_bytes > 0) { + // move the remaining received bytes to the start of the buffer, to + // be used for the next message + memmove(buffer, buffer + pos, received_bytes); + } pos = 0; } private: void consumeData(int bytes) { + assert(pos + bytes <= received_bytes); pos += bytes; } @@ -1470,7 +1509,7 @@ private: assert(cnt == 1); uint32_t keys = ctx->getMsgIStream()->read32(); uint32_t reqId = ctx->getMsgIStream()->read32(); - std::vector<std::string> values = ctx->getMsgIStream()->readStringlist(); + std::vector<std::string> values = ctx->getMsgIStream()->readStringlist(keys); ctx->reportUserdata(reqId, values); } @@ -1575,6 +1614,7 @@ class S2C_ENTERCHAT : public NetworkState { std::string error = std::string("Expected SID_ENTERCHAT, got msg id "); error += std::to_string(msg); ctx->setState(new DisconnectedState(error)); + return; } DebugPrint("TCP Recv: 0x0a ENTERCHAT\n"); @@ -1636,6 +1676,7 @@ class S2C_CREATEACCOUNT2 : public NetworkState { std::string error = std::string("Expected SID_CREATEACCOUNT2, got msg id "); error += std::to_string(msg); ctx->setState(new DisconnectedState(error)); + return; } DebugPrint("TCP Recv: 0x3d CREATEACCOUNT\n"); @@ -1714,7 +1755,10 @@ class S2C_LOGONRESPONSE2 : public NetworkState { if (msg != 0x3a) { std::string error = std::string("Expected SID_LOGONRESPONSE2, got msg id "); error += std::to_string(msg); - ctx->setState(new DisconnectedState(error)); + ctx->showError(error); + ctx->getMsgIStream()->debugDump(); + ctx->getMsgIStream()->finishMessage(); + return; } DebugPrint("TCP Sent: 0x3a LOGONRESPONSE\n"); @@ -1747,6 +1791,7 @@ class S2C_LOGONRESPONSE2 : public NetworkState { return; default: ctx->setState(new DisconnectedState("unknown logon response")); + return; } } } @@ -1816,6 +1861,7 @@ class S2C_SID_AUTH_CHECK : public NetworkState { std::string error = std::string("Expected SID_AUTH_CHECK, got msg id "); error += std::to_string(msg); ctx->setState(new DisconnectedState(error)); + return; } DebugPrint("TCP Recv: 0x51 AUTH_CHECK\n"); @@ -1877,6 +1923,7 @@ class S2C_SID_AUTH_INFO : public NetworkState { std::string error = std::string("Expected SID_AUTH_INFO, got msg id "); error += std::to_string(msg); ctx->setState(new DisconnectedState(error)); + return; } DebugPrint("TCP Recv: 0x50 AUTH_INFO\n"); @@ -1941,6 +1988,7 @@ class S2C_SID_PING : public NetworkState { std::string error = std::string("Expected SID_PING, got msg id "); error += std::to_string(msg); ctx->setState(new DisconnectedState(error)); + return; } DebugPrint("TCP Recv: 0x25 PING\n"); uint32_t pingValue = ctx->getMsgIStream()->read32();