From 09384d051c5e8d4a359afc5e6f8977c8cf5a5632 Mon Sep 17 00:00:00 2001
From: Tim Felgentreff <timfelgentreff@gmail.com>
Date: Tue, 22 Dec 2020 21:30:11 +0100
Subject: [PATCH] add functions to run external AI agents

---
 src/ai/script_ai.cpp | 148 +++++++++++++++++++++++++++++++++++++++++++
 test_ai.py           |  84 ++++++++++++++++++++++++
 2 files changed, 232 insertions(+)
 create mode 100644 test_ai.py

diff --git a/src/ai/script_ai.cpp b/src/ai/script_ai.cpp
index 23f8e4243..7ddf2fcbf 100644
--- a/src/ai/script_ai.cpp
+++ b/src/ai/script_ai.cpp
@@ -34,6 +34,10 @@
 --  Includes
 ----------------------------------------------------------------------------*/
 
+#include "network.h"
+#include "network/netsockets.h"
+#include "net_lowlevel.h"
+#include "results.h"
 #include "stratagus.h"
 
 #include "ai.h"
@@ -1415,6 +1419,145 @@ static int CclDefineAiPlayer(lua_State *l)
 	return 0;
 }
 
+/**
+**  AiProcessorSetup(host, port, number_of_state_variables, number_of_actions)
+**
+**  Connect to an AI agent running at host:port that will consume
+**  number_of_state_variables every step and select one of number_of_actions.
+**  Both counts are sent as a single byte each, so they must be in 0-255.
+**
+**  @param l  Lua state.
+**
+**  @return   1 Lua value: a handle for the other AiProcessor functions, or
+**            nil if the connection could not be established.
+*/
+static int CclAiProcessorSetup(lua_State *l)
+{
+	LuaCheckArgs(l, 4);
+	const std::string host = LuaToString(l, 1);
+	const int port = LuaToNumber(l, 2);
+	const int stateDim = LuaToNumber(l, 3);
+	const int actionDim = LuaToNumber(l, 4);
+	if (stateDim < 0 || stateDim > 255 || actionDim < 0 || actionDim > 255) {
+		LuaError(l, "number of state variables and actions must be between 0 and 255");
+	}
+
+	InitNetwork1();
+	CHost h(host.c_str(), port);
+	CTCPSocket *s = new CTCPSocket();
+	s->Open(CHost());
+	if (s->Connect(h)) {
+		// Handshake: 'I', number of state variables, number of actions
+		char buf[3];
+		buf[0] = 'I';
+		buf[1] = (uint8_t)stateDim;
+		buf[2] = (uint8_t)actionDim;
+		s->Send(buf, 3);
+		lua_pushlightuserdata(l, s);
+		return 1;
+	}
+
+	delete s;
+	lua_pushnil(l);
+	return 1;
+}
+
+/**
+**  AiProcessorStep(handle, reward_since_last_call, table_of_state_variables)
+**
+**  A single step in a reinforcement learning loop: send the reward collected
+**  since the last call and the current state to the agent, then wait for the
+**  action it selects. The next call reports the updated state and the reward
+**  for that action.
+**
+**  Wire format: 'R' + 32-bit reward, then 'S' + one 32-bit value per state
+**  variable, all in network byte order. The agent answers with one byte.
+**
+**  @param l  Lua state.
+**
+**  @return   1 Lua value: the number of the action selected by the agent.
+*/
+static int CclAiProcessorStep(lua_State *l)
+{
+	LuaCheckArgs(l, 3);
+	CTCPSocket *s = (CTCPSocket *)lua_touserdata(l, 1);
+	if (s == NULL) {
+		LuaError(l, "first argument must be valid handle returned from a previous AiProcessorSetup call");
+	}
+	if (!lua_istable(l, 3)) {
+		LuaError(l, "3rd argument to AiProcessorStep must be table");
+	}
+
+	// Build both messages before sending anything, so that a bad argument
+	// cannot leave the agent with half a step.
+	const uint32_t reward = htonl(LuaToNumber(l, 2));
+	char buf[1 + sizeof(uint32_t)];
+	buf[0] = 'R';
+	// the payload follows the command byte; copying to buf itself would
+	// overwrite the 'R' and leave the last byte uninitialized
+	memcpy(buf + 1, &reward, sizeof(uint32_t));
+
+	char stepBuf[1 + 256 * sizeof(uint32_t)] = {'\0'}; // room for 256 variables
+	stepBuf[0] = 'S';
+	unsigned int i = 1;
+	for (lua_pushnil(l); lua_next(l, 3); lua_pop(l, 1)) {
+		// the key is ignored, only the values are sent
+		if (i + sizeof(uint32_t) > sizeof(stepBuf)) {
+			LuaError(l, "too many state variables");
+		}
+		const uint32_t var = htonl(LuaToNumber(l, -1));
+		memcpy(stepBuf + i, &var, sizeof(uint32_t));
+		i += sizeof(uint32_t);
+	}
+
+	s->Send(buf, sizeof(buf));
+	s->Send(stepBuf, i);
+
+	// Receive into a single byte: reading one byte into an int would not
+	// set its low byte on big-endian hosts.
+	uint8_t action = 0;
+	s->Recv(&action, 1);
+	lua_pushnumber(l, action);
+	return 1;
+}
+
+/**
+**  AiProcessorClose(handle, game_result)
+**
+**  Tell the agent how the game ended ('E' followed by 2 for a victory, 1 for
+**  a defeat and 0 for any other result), then close and free the connection.
+**  The handle must not be used afterwards.
+**
+**  @param l  Lua state.
+**
+**  @return   0, nothing is pushed onto the Lua stack.
+*/
+static int CclAiProcessorClose(lua_State *l)
+{
+	LuaCheckArgs(l, 2);
+	CTCPSocket *s = (CTCPSocket *)lua_touserdata(l, 1);
+	if (s == NULL) {
+		LuaError(l, "first argument must be valid handle returned from a previous AiProcessorSetup call");
+	}
+
+	char buf[2] = {'E', 0};
+	const int gameresult = LuaToNumber(l, 2);
+	switch (gameresult) {
+		case GameVictory:
+			buf[1] = 2;
+			break;
+		case GameDefeat:
+			buf[1] = 1;
+			break;
+		default:
+			break;
+	}
+	s->Send(buf, 2);
+	s->Close();
+	delete s;
+	return 0;
+}
+
 /**
 **  Register CCL features for unit-type.
 */
@@ -1458,6 +1601,11 @@ void AiCclRegister()
 	lua_register(Lua, "DefineAiPlayer", CclDefineAiPlayer);
 	lua_register(Lua, "AiAttackWithForces", CclAiAttackWithForces);
 	lua_register(Lua, "AiWaitForces", CclAiWaitForces);
+
+	// for external AI processors
+	lua_register(Lua, "AiProcessorSetup", CclAiProcessorSetup);
+	lua_register(Lua, "AiProcessorStep", CclAiProcessorStep);
+	lua_register(Lua, "AiProcessorClose", CclAiProcessorClose);
 }
 
 //@}
diff --git a/test_ai.py b/test_ai.py
new file mode 100644
index 000000000..81a6919be
--- /dev/null
+++ b/test_ai.py
@@ -0,0 +1,84 @@
+"""Minimal TCP agent for manually testing the AiProcessor* Lua functions.
+
+Messages sent by the game (all integers in network byte order):
+
+    I <number of state variables: 1 byte> <number of actions: 1 byte>
+    R <reward since the last step: 4 bytes>
+    S <state variables: 4 bytes each>, answered with a 1 byte action
+    E <game result: 1 byte>
+"""
+import socket
+import struct
+
+
+class Disconnected(Exception):
+    """The game closed the connection in the middle of a message."""
+
+
+def recv_exact(conn, count):
+    """Read exactly count bytes from conn or raise Disconnected."""
+    data = bytearray()
+    while len(data) < count:
+        chunk = conn.recv(count - len(data))
+        if not chunk:
+            # recv() keeps returning b"" once the peer is gone; without
+            # this check the loop would never end
+            raise Disconnected(f"expected {count} bytes, got {len(data)}")
+        data += chunk
+    return bytes(data)
+
+
+def serve(conn):
+    """Answer the messages of one game until it closes the connection."""
+    num_state = 0
+    num_actions = 0
+    act = 0
+    long_size = struct.calcsize("!l")
+
+    while True:
+        command = conn.recv(1)
+        if not command:
+            break
+        print(command)
+        if command == b"I":
+            num_state, num_actions = recv_exact(conn, 2)
+            print("setup", num_state, num_actions)
+        elif command == b"R":
+            (reward,) = struct.unpack("!l", recv_exact(conn, long_size))
+            print("reward", reward)
+        elif command == b"S":
+            state = recv_exact(conn, long_size * num_state)
+            args = struct.unpack(f"!{num_state}l", state)
+            print("step", args)
+            # cycle through the actions; answer 0 if no action count was
+            # announced, instead of dividing by zero
+            act = act % num_actions if num_actions else 0
+            print("action", act)
+            conn.sendall(bytes([act]))
+            act += 1
+        elif command == b"E":
+            (result,) = recv_exact(conn, 1)
+            print("end", result)
+
+
+def main():
+    """Accept game connections on 127.0.0.1:9292, one at a time."""
+    local_ip = "127.0.0.1"
+    local_port = 9292
+    with socket.socket(family=socket.AF_INET) as sock:
+        sock.bind((local_ip, local_port))
+        sock.listen(5)
+
+        while True:
+            print("TCP server up and listening on", local_ip, local_port)
+            client, address = sock.accept()
+            print("connection", address)
+            with client:
+                try:
+                    serve(client)
+                except Disconnected as err:
+                    print("connection lost:", err)
+
+
+if __name__ == "__main__":
+    main()