From 09384d051c5e8d4a359afc5e6f8977c8cf5a5632 Mon Sep 17 00:00:00 2001
From: Tim Felgentreff <timfelgentreff@gmail.com>
Date: Tue, 22 Dec 2020 21:30:11 +0100
Subject: [PATCH] add functions to run external AI agents

---
 src/ai/script_ai.cpp | 117 +++++++++++++++++++++++++++++++++++++++++++
 test_ai.py           |  54 ++++++++++++++++++++
 2 files changed, 171 insertions(+)
 create mode 100644 test_ai.py

diff --git a/src/ai/script_ai.cpp b/src/ai/script_ai.cpp
index 23f8e4243..7ddf2fcbf 100644
--- a/src/ai/script_ai.cpp
+++ b/src/ai/script_ai.cpp
@@ -34,6 +34,10 @@
 --  Includes
 ----------------------------------------------------------------------------*/
 
+#include "network.h"
+#include "network/netsockets.h"
+#include "net_lowlevel.h"
+#include "results.h"
 #include "stratagus.h"
 
 #include "ai.h"
@@ -1415,6 +1419,114 @@ static int CclDefineAiPlayer(lua_State *l)
 	return 0;
 }
 
+/**
+ * AiProcessorSetup(host, port, number_of_state_variables, number_of_actions)
+ *
+ * Connect to an AI agent at host:port that consumes number_of_state_variables
+ * per step and picks one of number_of_actions (each count must be 0..255).
+ * Returns a handle for the other AiProcessor* functions, or nil on failure.
+ */
+static int CclAiProcessorSetup(lua_State *l)
+{
+	LuaCheckArgs(l, 4);
+	const std::string host = LuaToString(l, 1);
+	const int port = LuaToNumber(l, 2);
+	const int stateDim = LuaToNumber(l, 3);
+	const int actionDim = LuaToNumber(l, 4);
+	if (stateDim < 0 || stateDim > 255 || actionDim < 0 || actionDim > 255) {
+		LuaError(l, "number of state variables and number of actions must each be within 0..255");
+	}
+
+	InitNetwork1();
+	CHost h(host.c_str(), port);
+	CTCPSocket *s = new CTCPSocket();
+	s->Open(CHost());
+	if (s->Connect(h)) {
+		const char buf[3] = {'I', (char)stateDim, (char)actionDim};
+		s->Send(buf, 3);
+		lua_pushlightuserdata(l, s);
+		return 1;
+	}
+	delete s;
+	lua_pushnil(l);
+	return 1;
+}
+
+/**
+ * AiProcessorStep(handle, reward_since_last_call, table_of_state_variables)
+ *
+ * One reinforcement learning step: send the reward earned since the last call
+ * and the current state variables to the agent, return the action it selects.
+ */
+static int CclAiProcessorStep(lua_State *l)
+{
+	LuaCheckArgs(l, 3);
+	CTCPSocket *s = (CTCPSocket *)lua_touserdata(l, 1);
+	if (s == NULL) {
+		LuaError(l, "first argument must be valid handle returned from a previous AiProcessorSetup call");
+	}
+	// Validate and pack everything before sending, so that a bad call never
+	// leaves a half-finished exchange on the wire.
+	if (!lua_istable(l, 3)) {
+		LuaError(l, "3rd argument to AiProcessorStep must be table");
+	}
+
+	// Reward message: 'R' followed by the reward in network byte order.
+	const uint32_t reward = htonl(LuaToNumber(l, 2));
+	char buf[5];
+	buf[0] = 'R';
+	memcpy(buf + 1, &reward, sizeof(uint32_t));
+
+	// State message: 'S' followed by each variable in network byte order.
+	char stepBuf[1025] = {'\0'}; // room for 256 variables
+	stepBuf[0] = 'S';
+	int used = 1;
+	for (lua_pushnil(l); lua_next(l, 3); lua_pop(l, 1)) {
+		// keys are ignored, values are packed in traversal order
+		if (used + sizeof(uint32_t) > sizeof(stepBuf)) {
+			LuaError(l, "too many state variables");
+		}
+		const uint32_t var = htonl(LuaToNumber(l, -1));
+		memcpy(stepBuf + used, &var, sizeof(uint32_t));
+		used += sizeof(uint32_t);
+	}
+
+	s->Send(buf, 5);
+	s->Send(stepBuf, used);
+
+	// The agent answers with a single byte. Read it into a byte-sized variable
+	// so that the result does not depend on host endianness.
+	uint8_t action = 0;
+	s->Recv(&action, 1);
+	lua_pushnumber(l, action);
+	return 1;
+}
+
+/**
+ * AiProcessorClose(handle, game_result): send 'E' plus a result byte (2 for
+ * victory, 1 for defeat, 0 otherwise) to the agent, then free the handle.
+ */
+static int CclAiProcessorClose(lua_State *l)
+{
+	LuaCheckArgs(l, 2);
+	CTCPSocket *sock = (CTCPSocket *)lua_touserdata(l, 1);
+	if (sock == NULL) {
+		LuaError(l, "first argument must be valid handle returned from a previous AiProcessorSetup call");
+	}
+
+	const int gameresult = LuaToNumber(l, 2);
+	char msg[2] = {'E', '\0'};
+	if (gameresult == GameVictory) {
+		msg[1] = '\2';
+	} else if (gameresult == GameDefeat) {
+		msg[1] = '\1';
+	}
+	sock->Send(msg, 2);
+	sock->Close();
+	delete sock;
+	return 0;
+}
+
 /**
 **  Register CCL features for unit-type.
 */
@@ -1458,6 +1570,11 @@ void AiCclRegister()
 	lua_register(Lua, "DefineAiPlayer", CclDefineAiPlayer);
 	lua_register(Lua, "AiAttackWithForces", CclAiAttackWithForces);
 	lua_register(Lua, "AiWaitForces", CclAiWaitForces);
+
+	// for external AI processors
+	lua_register(Lua, "AiProcessorSetup", CclAiProcessorSetup);
+	lua_register(Lua, "AiProcessorStep", CclAiProcessorStep);
+	lua_register(Lua, "AiProcessorClose", CclAiProcessorClose);
 }
 
 //@}
diff --git a/test_ai.py b/test_ai.py
new file mode 100644
index 000000000..81a6919be
--- /dev/null
+++ b/test_ai.py
@@ -0,0 +1,54 @@
+import socket
+import struct
+import random
+
+
+def recv_exact(conn, count):
+    """Read exactly *count* bytes from *conn*; raise EOFError if it closes."""
+    data = b""
+    while len(data) < count:
+        chunk = conn.recv(count - len(data))
+        if not chunk:
+            raise EOFError("connection closed in the middle of a message")
+        data += chunk
+    return data
+
+
+def serve(conn):
+    """Answer the game's messages on *conn* until it disconnects."""
+    num_state = num_actions = act = 0
+    long_size = struct.calcsize("!l")
+    for command in iter(lambda: conn.recv(1), b""):  # b"" means peer closed
+        print(command)
+        if command == b"I":
+            num_state, num_actions = struct.unpack("BB", recv_exact(conn, 2))
+            print("setup", num_state, num_actions)
+        elif command == b"R":
+            print("reward", struct.unpack("!l", recv_exact(conn, long_size))[0])
+        elif command == b"S":
+            payload = recv_exact(conn, long_size * num_state)
+            print("step", struct.unpack("!" + "l" * num_state, payload))
+            act %= max(num_actions, 1)  # cycle actions; safe with none set up
+            print("action", act)
+            conn.sendall(bytearray([act]))
+            act += 1
+        elif command == b"E":
+            print("end", ord(recv_exact(conn, 1)))
+
+
+if __name__ == "__main__":
+    # Accept TCP connections from the game, one at a time.
+    localIP, localPort = "127.0.0.1", 9292
+    sock = socket.socket(family=socket.AF_INET)
+    sock.bind((localIP, localPort))
+    sock.listen(5)
+    while True:
+        print("TCP server up and listening on", localIP, localPort)
+        (clientsocket, address) = sock.accept()
+        print("connection", address)
+        try:
+            serve(clientsocket)
+        except (EOFError, OSError) as err:
+            print("connection lost:", err)
+        finally:
+            clientsocket.close()