fix various conditions in which the online service message buffer would become corrupted

This commit is contained in:
Tim Felgentreff 2020-12-13 11:41:28 +01:00
parent 7e20aab0b7
commit d0ed8f9874

View file

@ -97,8 +97,8 @@ public:
~BNCSInputStream() {};
std::string readString() {
if (received_bytes == 0) {
return NULL;
if (received_bytes - pos <= 0) {
return "";
}
std::stringstream strstr;
int i = pos;
@ -124,6 +124,19 @@ public:
return stringlist;
};
std::vector<std::string> readStringlist(int cnt) {
std::vector<std::string> stringlist;
for (; cnt >= 0; cnt--) {
std::string nxt = readString();
if (nxt.empty()) {
break;
} else {
stringlist.push_back(nxt);
}
}
return stringlist;
};
uint8_t read8() {
uint8_t byte = buffer[pos];
consumeData(1);
@ -188,12 +201,24 @@ public:
// (UINT8) Message ID
// (UINT16) Message length, including this header
// (VOID) Message data
received_bytes += this->sock->Recv(buffer + received_bytes, 4 - received_bytes);
if (received_bytes != 4) {
if (received_bytes < 4) {
// in case of retry, we may already have the first 4 bytes
received_bytes += this->sock->Recv(buffer + received_bytes, 4 - received_bytes);
}
if (received_bytes < 4) {
// didn't get the complete header yet
return -1;
}
assert(pos == 0);
uint8_t headerbyte = read8();
if (headerbyte != 0xff) {
// Likely a bug on our side. We just skip this byte.
debugDump();
memmove(buffer, buffer + 1, received_bytes - 1);
received_bytes -= 1;
pos = 0;
return -1;
}
assert(received_bytes == 4);
assert(read8() == 0xff);
uint8_t msgId = read8();
uint16_t len = read16();
// we still need to have len in total for this message, so if we have
@ -201,9 +226,9 @@ public:
// first 4 bytes that we already consumed, we'll have enough
assert(pos == 4);
long needed = len - received_bytes;
if (needed != 0) {
assert(needed > 0);
if (len >= bufsize) {
if (needed > 0) {
if (needed >= bufsize + received_bytes) {
// we never shrink the buffer again while we're online
buffer = (char*)realloc(buffer, sizeof(char) * len + 1);
bufsize = len + 1;
}
@ -218,13 +243,27 @@ public:
return msgId;
};
void debugDump() {
if (EnableDebugPrint) {
std::cout << "Input stream state: pos(" << pos << "), received_bytes(" << received_bytes << ")" << std::endl;
dump((uint8_t*)buffer, received_bytes);
}
}
void finishMessage() {
received_bytes = 0;
assert(pos <= received_bytes);
received_bytes = received_bytes - pos;
if (received_bytes > 0) {
// move the remaining received bytes to the start of the buffer, to
// be used for the next message
memmove(buffer, buffer + pos, received_bytes);
}
pos = 0;
}
private:
void consumeData(int bytes) {
assert(pos + bytes <= received_bytes);
pos += bytes;
}
@ -1470,7 +1509,7 @@ private:
assert(cnt == 1);
uint32_t keys = ctx->getMsgIStream()->read32();
uint32_t reqId = ctx->getMsgIStream()->read32();
std::vector<std::string> values = ctx->getMsgIStream()->readStringlist();
std::vector<std::string> values = ctx->getMsgIStream()->readStringlist(keys);
ctx->reportUserdata(reqId, values);
}
@ -1575,6 +1614,7 @@ class S2C_ENTERCHAT : public NetworkState {
std::string error = std::string("Expected SID_ENTERCHAT, got msg id ");
error += std::to_string(msg);
ctx->setState(new DisconnectedState(error));
return;
}
DebugPrint("TCP Recv: 0x0a ENTERCHAT\n");
@ -1636,6 +1676,7 @@ class S2C_CREATEACCOUNT2 : public NetworkState {
std::string error = std::string("Expected SID_CREATEACCOUNT2, got msg id ");
error += std::to_string(msg);
ctx->setState(new DisconnectedState(error));
return;
}
DebugPrint("TCP Recv: 0x3d CREATEACCOUNT\n");
@ -1714,7 +1755,10 @@ class S2C_LOGONRESPONSE2 : public NetworkState {
if (msg != 0x3a) {
std::string error = std::string("Expected SID_LOGONRESPONSE2, got msg id ");
error += std::to_string(msg);
ctx->setState(new DisconnectedState(error));
ctx->showError(error);
ctx->getMsgIStream()->debugDump();
ctx->getMsgIStream()->finishMessage();
return;
}
DebugPrint("TCP Sent: 0x3a LOGONRESPONSE\n");
@ -1747,6 +1791,7 @@ class S2C_LOGONRESPONSE2 : public NetworkState {
return;
default:
ctx->setState(new DisconnectedState("unknown logon response"));
return;
}
}
}
@ -1816,6 +1861,7 @@ class S2C_SID_AUTH_CHECK : public NetworkState {
std::string error = std::string("Expected SID_AUTH_CHECK, got msg id ");
error += std::to_string(msg);
ctx->setState(new DisconnectedState(error));
return;
}
DebugPrint("TCP Recv: 0x51 AUTH_CHECK\n");
@ -1877,6 +1923,7 @@ class S2C_SID_AUTH_INFO : public NetworkState {
std::string error = std::string("Expected SID_AUTH_INFO, got msg id ");
error += std::to_string(msg);
ctx->setState(new DisconnectedState(error));
return;
}
DebugPrint("TCP Recv: 0x50 AUTH_INFO\n");
@ -1941,6 +1988,7 @@ class S2C_SID_PING : public NetworkState {
std::string error = std::string("Expected SID_PING, got msg id ");
error += std::to_string(msg);
ctx->setState(new DisconnectedState(error));
return;
}
DebugPrint("TCP Recv: 0x25 PING\n");
uint32_t pingValue = ctx->getMsgIStream()->read32();