From d0ed8f9874f6693cdcdac0f5e2532705f607d566 Mon Sep 17 00:00:00 2001
From: Tim Felgentreff <timfelgentreff@gmail.com>
Date: Sun, 13 Dec 2020 11:41:28 +0100
Subject: [PATCH] fix various conditions in which the online service message
 buffer would become corrupted

---
 src/network/online_service.cpp | 72 ++++++++++++++++++++++++++++------
 1 file changed, 60 insertions(+), 12 deletions(-)

diff --git a/src/network/online_service.cpp b/src/network/online_service.cpp
index 03547860d..6a797ae9b 100644
--- a/src/network/online_service.cpp
+++ b/src/network/online_service.cpp
@@ -97,8 +97,8 @@ public:
     ~BNCSInputStream() {};
 
     std::string readString() {
-        if (received_bytes == 0) {
-            return NULL;
+        if (received_bytes - pos <= 0) {
+            return "";
         }
         std::stringstream strstr;
         int i = pos;
@@ -124,6 +124,19 @@ public:
         return stringlist;
     };
 
+    std::vector<std::string> readStringlist(int cnt) {
+        std::vector<std::string> stringlist;
+        for (; cnt >= 0; cnt--) {
+            std::string nxt = readString();
+            if (nxt.empty()) {
+                break;
+            } else {
+                stringlist.push_back(nxt);
+            }
+        }
+        return stringlist;
+    };
+
     uint8_t read8() {
         uint8_t byte = buffer[pos];
         consumeData(1);
@@ -188,12 +201,24 @@ public:
         //  (UINT8) Message ID
         // (UINT16) Message length, including this header
         //   (VOID) Message data
-        received_bytes += this->sock->Recv(buffer + received_bytes, 4 - received_bytes);
-        if (received_bytes != 4) {
+        if (received_bytes < 4) {
+            // in case of retry, we may already have the first 4 bytes
+            received_bytes += this->sock->Recv(buffer + received_bytes, 4 - received_bytes);
+        }
+        if (received_bytes < 4) {
+            // didn't get the complete header yet
+            return -1;
+        }
+        assert(pos == 0);
+        uint8_t headerbyte = read8();
+        if (headerbyte != 0xff) {
+            // Likely a bug on our side. We just skip this byte.
+            debugDump();
+            memmove(buffer, buffer + 1, received_bytes - 1);
+            received_bytes -= 1;
+            pos = 0;
             return -1;
         }
-        assert(received_bytes == 4);
-        assert(read8() == 0xff);
         uint8_t msgId = read8();
         uint16_t len = read16();
         // we still need to have len in total for this message, so if we have
@@ -201,9 +226,9 @@ public:
         // first 4 bytes that we already consumed, we'll have enough
         assert(pos == 4);
         long needed = len - received_bytes;
-        if (needed != 0) {
-            assert(needed > 0);
-            if (len >= bufsize) {
+        if (needed > 0) {
+            if (needed >= bufsize + received_bytes) {
+                // we never shrink the buffer again while we're online
                 buffer = (char*)realloc(buffer, sizeof(char) * len + 1);
                 bufsize = len + 1;
             }
@@ -218,13 +243,27 @@ public:
         return msgId;
     };
 
+    void debugDump() {
+        if (EnableDebugPrint) {
+            std::cout << "Input stream state: pos(" << pos << "), received_bytes(" << received_bytes << ")" << std::endl;
+            dump((uint8_t*)buffer, received_bytes);
+        }
+    }
+
     void finishMessage() {
-        received_bytes = 0;
+        assert(pos <= received_bytes);
+        received_bytes = received_bytes - pos;
+        if (received_bytes > 0) {
+            // move the remaining received bytes to the start of the buffer, to
+            // be used for the next message
+            memmove(buffer, buffer + pos, received_bytes);
+        }
         pos = 0;
     }
 
 private:
     void consumeData(int bytes) {
+        assert(pos + bytes <= received_bytes);
         pos += bytes;
     }
 
@@ -1470,7 +1509,7 @@ private:
         assert(cnt == 1);
         uint32_t keys = ctx->getMsgIStream()->read32();
         uint32_t reqId = ctx->getMsgIStream()->read32();
-        std::vector<std::string> values = ctx->getMsgIStream()->readStringlist();
+        std::vector<std::string> values = ctx->getMsgIStream()->readStringlist(keys);
         ctx->reportUserdata(reqId, values);
     }
 
@@ -1575,6 +1614,7 @@ class S2C_ENTERCHAT : public NetworkState {
                 std::string error = std::string("Expected SID_ENTERCHAT, got msg id ");
                 error += std::to_string(msg);
                 ctx->setState(new DisconnectedState(error));
+                return;
             }
             DebugPrint("TCP Recv: 0x0a ENTERCHAT\n");
 
@@ -1636,6 +1676,7 @@ class S2C_CREATEACCOUNT2 : public NetworkState {
                 std::string error = std::string("Expected SID_CREATEACCOUNT2, got msg id ");
                 error += std::to_string(msg);
                 ctx->setState(new DisconnectedState(error));
+                return;
             }
             DebugPrint("TCP Recv: 0x3d CREATEACCOUNT\n");
 
@@ -1714,7 +1755,10 @@ class S2C_LOGONRESPONSE2 : public NetworkState {
             if (msg != 0x3a) {
                 std::string error = std::string("Expected SID_LOGONRESPONSE2, got msg id ");
                 error += std::to_string(msg);
-                ctx->setState(new DisconnectedState(error));
+                ctx->showError(error);
+                ctx->getMsgIStream()->debugDump();
+                ctx->getMsgIStream()->finishMessage();
+                return;
             }
             DebugPrint("TCP Sent: 0x3a LOGONRESPONSE\n");
 
@@ -1747,6 +1791,7 @@ class S2C_LOGONRESPONSE2 : public NetworkState {
                 return;
             default:
                 ctx->setState(new DisconnectedState("unknown logon response"));
+                return;
             }
         }
     }
@@ -1816,6 +1861,7 @@ class S2C_SID_AUTH_CHECK : public NetworkState {
                 std::string error = std::string("Expected SID_AUTH_CHECK, got msg id ");
                 error += std::to_string(msg);
                 ctx->setState(new DisconnectedState(error));
+                return;
             }
             DebugPrint("TCP Recv: 0x51 AUTH_CHECK\n");
 
@@ -1877,6 +1923,7 @@ class S2C_SID_AUTH_INFO : public NetworkState {
                 std::string error = std::string("Expected SID_AUTH_INFO, got msg id ");
                 error += std::to_string(msg);
                 ctx->setState(new DisconnectedState(error));
+                return;
             }
             DebugPrint("TCP Recv: 0x50 AUTH_INFO\n");
 
@@ -1941,6 +1988,7 @@ class S2C_SID_PING : public NetworkState {
                 std::string error = std::string("Expected SID_PING, got msg id ");
                 error += std::to_string(msg);
                 ctx->setState(new DisconnectedState(error));
+                return;
             }
             DebugPrint("TCP Recv: 0x25 PING\n");
             uint32_t pingValue = ctx->getMsgIStream()->read32();