add functions to run external AI agents

This commit is contained in:
Tim Felgentreff 2020-12-22 21:30:11 +01:00
parent a85e31932c
commit 09384d051c
2 changed files with 171 additions and 0 deletions

View file

@ -34,6 +34,10 @@
-- Includes
----------------------------------------------------------------------------*/
#include "network.h"
#include "network/netsockets.h"
#include "net_lowlevel.h"
#include "results.h"
#include "stratagus.h"
#include "ai.h"
@ -1415,6 +1419,114 @@ static int CclDefineAiPlayer(lua_State *l)
return 0;
}
/**
**  AiProcessorSetup(host, port, number_of_state_variables, number_of_actions)
**
**  Connect to an AI agent running at host:port, that will consume
**  number_of_state_variables every step and select one of number_of_actions.
**
**  On a successful connection a 3 byte setup message is sent to the agent:
**  the tag 'I', the number of state variables and the number of actions,
**  one byte each. Both counts must therefore lie in the range 0..255.
**
**  @param l  Lua state.
**
**  @return   Number of values pushed on the Lua stack (always 1): a light
**            userdata handle wrapping the connected socket, or nil if the
**            connection could not be established.
*/
static int CclAiProcessorSetup(lua_State *l)
{
	LuaCheckArgs(l, 4);

	// Convert and validate the numeric arguments before any C++ object with
	// a destructor (std::string, the socket) exists.
	// NOTE(review): LuaError / LuaToNumber presumably raise a Lua error that
	// does not return normally -- confirm; if so, raising one later would
	// skip destructors and leak the socket.
	const int port = LuaToNumber(l, 2);
	const int stateDim = LuaToNumber(l, 3);
	const int actionDim = LuaToNumber(l, 4);
	// Each count travels as a single byte; reject values that would be
	// silently truncated and misreport the dimensions to the agent.
	if (stateDim < 0 || stateDim > 255) {
		LuaError(l, "number of state variables must be between 0 and 255");
	}
	if (actionDim < 0 || actionDim > 255) {
		LuaError(l, "number of actions must be between 0 and 255");
	}
	const std::string host = LuaToString(l, 1);

	InitNetwork1();

	CHost h(host.c_str(), port);
	CTCPSocket *s = new CTCPSocket();
	// NOTE(review): the result of Open() is not inspected here; a failure is
	// only detected indirectly through Connect() below.
	s->Open(CHost());
	if (!s->Connect(h)) {
		delete s;
		lua_pushnil(l);
		return 1;
	}

	// Setup message: 'I', <state variable count>, <action count>
	char buf[3];
	buf[0] = 'I';
	buf[1] = (char)(uint8_t)stateDim;
	buf[2] = (char)(uint8_t)actionDim;
	s->Send(buf, 3);

	// The socket is owned by the Lua side from now on and is released by
	// AiProcessorClose.
	lua_pushlightuserdata(l, s);
	return 1;
}
/**
**  AiProcessorStep(handle, reward_since_last_call, table_of_state_variables)
**
**  A single step in a reinforcement learning loop. The reward earned since
**  the previous call and the current state are sent to the agent, and the
**  agent answers with the next action to take.
**
**  Wire format (all integers are 32 bit, network byte order):
**    'R' <reward>              -- 5 bytes
**    'S' <var> <var> ...       -- 1 + 4 * number_of_values bytes
**  The agent replies with a single byte holding the selected action.
**
**  NOTE(review): the table is walked with lua_next, whose enumeration order
**  is not specified by Lua; the keys are ignored, so the order of values on
**  the wire follows the traversal order, not the table indices.
**
**  @param l  Lua state.
**
**  @return   Number of values pushed on the Lua stack (always 1): the action
**            chosen by the agent, as a number.
*/
static int CclAiProcessorStep(lua_State *l)
{
	LuaCheckArgs(l, 3);
	CTCPSocket *s = (CTCPSocket *)lua_touserdata(l, 1);
	if (s == NULL) {
		LuaError(l, "first argument must be valid handle returned from a previous AiProcessorSetup call");
	}
	// Validate every argument before anything is written to the socket, so a
	// bad call never leaves a half-sent message on the wire.
	if (!lua_istable(l, 3)) {
		LuaError(l, "3rd argument to AiProcessorStep must be table");
	}

	// Go through int32_t first: converting a negative floating point value
	// directly to an unsigned type is undefined behaviour.
	const uint32_t reward = htonl((uint32_t)(int32_t)LuaToNumber(l, 2));
	char rewardBuf[1 + sizeof(uint32_t)];
	rewardBuf[0] = 'R';
	// Payload starts after the tag byte.
	memcpy(rewardBuf + 1, &reward, sizeof(uint32_t));

	char stepBuf[1025] = {'\0'}; // tag byte + room for 256 variables
	stepBuf[0] = 'S';
	int used = 1;
	for (lua_pushnil(l); lua_next(l, 3); lua_pop(l, 1)) {
		// key is ignored
		// Check capacity before writing so that exactly 256 values fit.
		if (used + (int)sizeof(uint32_t) > (int)sizeof(stepBuf)) {
			LuaError(l, "too many state variables");
		}
		const uint32_t var = htonl((uint32_t)(int32_t)LuaToNumber(l, -1));
		memcpy(stepBuf + used, &var, sizeof(uint32_t));
		used += sizeof(uint32_t);
	}

	s->Send(rewardBuf, sizeof(rewardBuf));
	s->Send(stepBuf, used);

	// The reply is exactly one byte; read it into a byte-sized variable so
	// the value does not depend on host endianness.
	// NOTE(review): the result of Recv is not checked; if nothing is
	// received the action stays 0.
	uint8_t action = 0;
	s->Recv(&action, 1);
	lua_pushnumber(l, action);
	return 1;
}
/**
**  AiProcessorClose(handle, game_result)
**
**  Tell the agent how the game ended and release the connection. A 2 byte
**  message is sent: the tag 'E' followed by 2 for GameVictory, 1 for
**  GameDefeat and 0 for any other result. Afterwards the socket is closed
**  and deleted, so the handle must not be used again.
**
**  @param l  Lua state.
**
**  @return   Number of values pushed on the Lua stack (always 0).
*/
static int CclAiProcessorClose(lua_State *l)
{
	LuaCheckArgs(l, 2);
	CTCPSocket *agent = (CTCPSocket *)lua_touserdata(l, 1);
	if (agent == NULL) {
		LuaError(l, "first argument must be valid handle returned from a previous AiProcessorSetup call");
	}

	// Map the game result onto the single outcome byte of the protocol.
	const int result = LuaToNumber(l, 2);
	char outcome = 0;
	if (result == GameVictory) {
		outcome = 2;
	} else if (result == GameDefeat) {
		outcome = 1;
	}

	const char message[2] = {'E', outcome};
	agent->Send(message, 2);

	agent->Close();
	delete agent;
	return 0;
}
/**
** Register CCL features for AI.
*/
@ -1458,6 +1570,11 @@ void AiCclRegister()
lua_register(Lua, "DefineAiPlayer", CclDefineAiPlayer);
lua_register(Lua, "AiAttackWithForces", CclAiAttackWithForces);
lua_register(Lua, "AiWaitForces", CclAiWaitForces);
// for external AI processors
lua_register(Lua, "AiProcessorSetup", CclAiProcessorSetup);
lua_register(Lua, "AiProcessorStep", CclAiProcessorStep);
lua_register(Lua, "AiProcessorClose", CclAiProcessorClose);
}
//@}

54
test_ai.py Normal file
View file

@ -0,0 +1,54 @@
import socket
import struct
import random
def recv_exact(conn, count):
    """Read exactly `count` bytes from the connected socket `conn`.

    A single recv() call on a TCP socket may return fewer bytes than
    requested, so keep reading until the full amount has arrived.

    Returns the received bytes (b"" when `count` is 0).
    Raises ConnectionError if the peer closes the connection before
    `count` bytes were received.
    """
    data = b""
    while len(data) < count:
        chunk = conn.recv(count - len(data))
        if not chunk:
            # recv() returning b"" means the peer closed the connection;
            # without this check the loop would spin forever.
            raise ConnectionError("peer closed the connection")
        data += chunk
    return data


if __name__ == "__main__":
    # Listen for incoming TCP connections
    localIP = "127.0.0.1"
    localPort = 9292

    sock = socket.socket(family=socket.AF_INET)
    sock.bind((localIP, localPort))
    sock.listen(5)

    while True:
        print("TCP server up and listening on", localIP, localPort)
        (clientsocket, address) = sock.accept()
        print("connection", address)

        # Per-connection protocol state. The defaults keep an "S" command
        # that arrives before any "I" command from crashing the server.
        num_state = 0
        num_actions = 0
        state_unpack_fmt = "!"
        act = 0
        long_size = struct.calcsize('!l')

        # Close the client socket when the session ends, however it ends.
        with clientsocket:
            try:
                while True:
                    command = clientsocket.recv(1)
                    if not command:
                        break
                    print(command)
                    if command == b"I":
                        # Setup: one byte each for state and action counts
                        num_state = ord(recv_exact(clientsocket, 1))
                        num_actions = ord(recv_exact(clientsocket, 1))
                        state_unpack_fmt = "!" + "l" * num_state
                        print("setup", num_state, num_actions)
                    elif command == b"R":
                        # Reward: one signed 32 bit integer, network order
                        r = recv_exact(clientsocket, long_size)
                        reward = struct.unpack("!l", r)[0]
                        print("reward", reward)
                    elif command == b"S":
                        # State: num_state signed 32 bit integers
                        r = recv_exact(clientsocket, long_size * num_state)
                        args = struct.unpack(state_unpack_fmt, r)
                        print("step", args)
                        # act = random.choice(range(num_actions))
                        # Cycle through the actions; fall back to 0 when no
                        # action count is known, to avoid dividing by zero.
                        act = act % num_actions if num_actions else 0
                        print("action", act)
                        clientsocket.sendall(bytearray([act]))
                        act += 1
                    elif command == b"E":
                        # End of game: one result byte
                        e = ord(recv_exact(clientsocket, 1))
                        print("end", e)
            except ConnectionError:
                # Peer went away in the middle of a message; drop this
                # session and wait for the next connection.
                print("connection lost", address)