add functions to run external AI agents
This commit is contained in:
parent
a85e31932c
commit
09384d051c
2 changed files with 171 additions and 0 deletions
|
@ -34,6 +34,10 @@
|
|||
-- Includes
|
||||
----------------------------------------------------------------------------*/
|
||||
|
||||
#include "network.h"
|
||||
#include "network/netsockets.h"
|
||||
#include "net_lowlevel.h"
|
||||
#include "results.h"
|
||||
#include "stratagus.h"
|
||||
|
||||
#include "ai.h"
|
||||
|
@ -1415,6 +1419,114 @@ static int CclDefineAiPlayer(lua_State *l)
|
|||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* AiProcessorSetup(host, port, number_of_state_variables, number_of_actions)
|
||||
*
|
||||
* Connect to an AI agent running at host:port, that will consume
|
||||
* number_of_state_variables every step and select one of number_of_actions.
|
||||
*/
|
||||
static int CclAiProcessorSetup(lua_State *l)
|
||||
{
|
||||
InitNetwork1();
|
||||
LuaCheckArgs(l, 4);
|
||||
std::string host = LuaToString(l, 1);
|
||||
int port = LuaToNumber(l, 2);
|
||||
int stateDim = LuaToNumber(l, 3);
|
||||
int actionDim = LuaToNumber(l, 4);
|
||||
|
||||
CHost h(host.c_str(), port);
|
||||
CTCPSocket *s = new CTCPSocket();
|
||||
s->Open(CHost());
|
||||
if (s->Connect(h)) {
|
||||
char buf[3];
|
||||
buf[0] = 'I';
|
||||
buf[1] = (uint8_t)stateDim;
|
||||
buf[2] = (uint8_t)actionDim;
|
||||
s->Send(buf, 3);
|
||||
lua_pushlightuserdata(l, s);
|
||||
return 1;
|
||||
}
|
||||
|
||||
delete s;
|
||||
lua_pushnil(l);
|
||||
return 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* AiProcessorStep(handle, reward_since_last_call, table_of_state_variables)
|
||||
*/
|
||||
static int CclAiProcessorStep(lua_State *l)
|
||||
{
|
||||
// A single step in a reinforcement learning network
|
||||
|
||||
// We receive the current env and current reward in the arguments
|
||||
|
||||
// We need to return the next action.
|
||||
|
||||
// The next call to this function will be the updated state, reward for the
|
||||
// last action
|
||||
|
||||
LuaCheckArgs(l, 3);
|
||||
CTCPSocket *s = (CTCPSocket *)lua_touserdata(l, 1);
|
||||
if (s == NULL) {
|
||||
LuaError(l, "first argument must be valid handle returned from a previous AiProcessorSetup call");
|
||||
}
|
||||
|
||||
uint32_t reward = htonl(LuaToNumber(l, 2));
|
||||
char buf[5];
|
||||
buf[0] = 'R';
|
||||
memcpy(buf, &reward, sizeof(uint32_t));
|
||||
s->Send(buf, 5);
|
||||
|
||||
if (!lua_istable(l, 3)) {
|
||||
LuaError(l, "3rd argument to AiProcessorStep must be table");
|
||||
}
|
||||
|
||||
char stepBuf[1025] = {'\0'}; // room for 256 variables
|
||||
stepBuf[0] = 'S';
|
||||
int i = 1;
|
||||
for (lua_pushnil(l); lua_next(l, 3); lua_pop(l, 1)) {
|
||||
// idx is ignored
|
||||
uint32_t var = htonl(LuaToNumber(l, -1));
|
||||
memcpy(stepBuf + i, &var, sizeof(uint32_t));
|
||||
i += sizeof(uint32_t);
|
||||
if (i + sizeof(uint32_t) > 1025) {
|
||||
LuaError(l, "too many state variables");
|
||||
}
|
||||
}
|
||||
s->Send(stepBuf, i);
|
||||
|
||||
int action = 0;
|
||||
s->Recv(&action, 1);
|
||||
lua_pushnumber(l, action);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int CclAiProcessorClose(lua_State *l)
|
||||
{
|
||||
LuaCheckArgs(l, 2);
|
||||
CTCPSocket *s = (CTCPSocket *)lua_touserdata(l, 1);
|
||||
if (s == NULL) {
|
||||
LuaError(l, "first argument must be valid handle returned from a previous AiProcessorSetup call");
|
||||
}
|
||||
|
||||
int gameresult = LuaToNumber(l, 2);
|
||||
switch (gameresult) {
|
||||
case GameVictory:
|
||||
s->Send("E\2", 2);
|
||||
break;
|
||||
case GameDefeat:
|
||||
s->Send("E\1", 2);
|
||||
break;
|
||||
default:
|
||||
s->Send("E\0", 2);
|
||||
break;
|
||||
}
|
||||
s->Close();
|
||||
delete s;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
** Register CCL features for unit-type.
|
||||
*/
|
||||
|
@ -1458,6 +1570,11 @@ void AiCclRegister()
|
|||
lua_register(Lua, "DefineAiPlayer", CclDefineAiPlayer);
|
||||
lua_register(Lua, "AiAttackWithForces", CclAiAttackWithForces);
|
||||
lua_register(Lua, "AiWaitForces", CclAiWaitForces);
|
||||
|
||||
// for external AI processors
|
||||
lua_register(Lua, "AiProcessorSetup", CclAiProcessorSetup);
|
||||
lua_register(Lua, "AiProcessorStep", CclAiProcessorStep);
|
||||
lua_register(Lua, "AiProcessorClose", CclAiProcessorClose);
|
||||
}
|
||||
|
||||
//@}
|
||||
|
|
54
test_ai.py
Normal file
54
test_ai.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
import socket
|
||||
import struct
|
||||
import random
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Listen for incoming datagrams
|
||||
localIP = "127.0.0.1"
|
||||
localPort = 9292
|
||||
buffersize = 1024
|
||||
sock = socket.socket(family=socket.AF_INET)
|
||||
sock.bind((localIP, localPort))
|
||||
sock.listen(5)
|
||||
|
||||
while True:
|
||||
print("TCP server up and listening on", localIP, localPort)
|
||||
(clientsocket, address) = sock.accept()
|
||||
print("connection", address)
|
||||
num_state = 0
|
||||
num_actions = 0
|
||||
act = 0
|
||||
long_size = struct.calcsize('!l')
|
||||
|
||||
while(True):
|
||||
command = clientsocket.recv(1)
|
||||
if not command:
|
||||
break
|
||||
print(command)
|
||||
if command == b"I":
|
||||
num_state = ord(clientsocket.recv(1))
|
||||
num_actions = ord(clientsocket.recv(1))
|
||||
state_unpack_fmt = "!" + "l" * num_state
|
||||
print("setup", num_state, num_actions)
|
||||
elif command == b"R":
|
||||
r = b""
|
||||
while len(r) < long_size:
|
||||
r += clientsocket.recv(long_size - len(r))
|
||||
reward = struct.unpack("!l", r)[0]
|
||||
print("reward", reward)
|
||||
elif command == b"S":
|
||||
r = b""
|
||||
expected = long_size * num_state
|
||||
while len(r) < expected:
|
||||
r += clientsocket.recv(expected - len(r))
|
||||
args = struct.unpack(state_unpack_fmt, r)
|
||||
print("step", args)
|
||||
# act = random.choice(range(num_actions))
|
||||
act = act % num_actions
|
||||
print("action", act)
|
||||
clientsocket.sendall(bytearray([act]))
|
||||
act += 1
|
||||
elif command == b"E":
|
||||
e = ord(clientsocket.recv(1))
|
||||
print("end", e)
|
Loading…
Add table
Reference in a new issue