# Source code for ailiga.battle

import argparse
import multiprocessing as mp

import numpy as np
import tqdm
from tianshou.data import Collector
from tianshou.env import DummyVectorEnv, RayVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy

from ailiga import env as menv
from ailiga.all_fighters import (
    get_all_fighters,
    get_fighter_by_name,
    get_fighters_from_list,
)


class Battle:
    """Runs a battle between two or more agents."""

    def __init__(self, lambda_env, agents):
        """Create the environment and set up the fighting agents.

        :param lambda_env: zero-argument factory returning a new environment
        :param agents: sequence of agent classes (each instantiated here with
            ``lambda_env``) or of already-constructed agent instances; every
            agent must provide ``get_policy()``
        :raises ValueError: if the number of agents does not match the number
            of agents of the environment (this includes an empty ``agents``)
        """
        self.env = lambda_env()
        self.lambda_env = lambda_env
        # Check for emptiness first so that an empty ``agents`` sequence reaches
        # the descriptive ValueError below instead of failing with IndexError.
        if len(agents) > 0 and isinstance(agents[0], type):
            # agents are classes, not instances: build them for this env
            self.agents = [a(self.lambda_env) for a in agents]
        else:
            self.agents = agents
        self.policies = [a.get_policy() for a in self.agents]
        self.env.reset()
        # Filled in by ``fight`` with the collector's "rews" / "lens" results.
        self.rews = None
        self.lens = None
        if len(self.env.agents) != len(self.agents):
            raise ValueError(
                "Agents do not match environment: "
                + str(self.env.agents)
                + " vs "
                + str(self.agents)
            )

    @staticmethod
    def _num_workers(n_episodes, n_jobs):
        """Return how many environment worker processes ``fight`` should start.

        :param n_episodes: number of episodes to collect, or ``None`` when
            collecting by steps
        :param n_jobs: requested number of workers; ``None`` means CPU count
        :raises ValueError: if ``n_jobs`` is smaller than 1
        """
        n_workers = mp.cpu_count() if n_jobs is None else n_jobs
        if n_workers < 1:
            raise ValueError("n_jobs must be at least 1, got " + str(n_jobs))
        if n_episodes is not None and n_episodes > 0:
            # Episode-based collection only steps min(env_num, n_episode)
            # environments, so additional workers would just sit idle.
            n_workers = min(n_workers, n_episodes)
        return n_workers

    def fight(self, n_episodes=1, n_step=None, render=None, n_jobs=None):
        """Run episodes between the agents and return their mean rewards.

        :param n_episodes: number of episodes to collect; ignored when
            ``n_step`` is given, because tianshou's ``Collector.collect``
            accepts only one of ``n_step`` / ``n_episode``
        :param n_step: number of environment steps to collect instead of a
            number of episodes
        :param render: forwarded to ``Collector.collect(render=...)``: sleep
            time between rendered frames, ``None`` for no rendering
        :param n_jobs: number of worker processes; defaults to the CPU count
            (capped at ``n_episodes`` when collecting by episode)
        :return: list with the mean episode reward of each agent, in agent
            order; every entry is ``nan`` if no episode finished
        :raises ValueError: if ``n_jobs`` is smaller than 1
        """
        policy = MultiAgentPolicyManager(self.policies, self.env)
        policy.eval()
        if n_step is not None:
            # Collector.collect rejects calls that set both n_step and n_episode.
            n_episodes = None
        n_workers = self._num_workers(n_episodes, n_jobs)
        # Each worker builds its own environment from the factory instead of
        # receiving a pickled copy of the live ``self.env``.
        venv = SubprocVectorEnv([self.lambda_env for _ in range(n_workers)])
        try:
            collector = Collector(policy, venv, exploration_noise=True)
            result = collector.collect(
                n_episode=n_episodes, n_step=n_step, render=render
            )
        finally:
            # Always shut the worker processes down, even if collection fails.
            venv.close()
        self.rews, self.lens = result["rews"], result["lens"]
        if len(self.rews) == 0:
            # No episode finished (possible with step-based collection).
            return [float("nan")] * len(self.agents)
        return [self.rews[:, i].mean() for i in range(len(self.agents))]
def battle(
    a_fighter=None,
    a_env="tictactoe_v3",
    a_n_episodes=10000,
    a_n_steps=None,
    render=False,
    n_jobs=None,
    a_force=False,
):
    """Run a battle between agents, print a summary and return the rewards.

    :param a_fighter: fighter names, resolved with ``get_fighters_from_list``
    :param a_env: environment name, resolved with ``menv.get_env``
    :param a_n_episodes: number of episodes to play
    :param a_n_steps: number of steps to collect instead of episodes
    :param render: forwarded to ``Battle.fight``
    :param n_jobs: number of worker processes, forwarded to ``Battle.fight``
    :param a_force: if True, skip the ``valid_env`` check and use every
        requested fighter
    :return: per-agent mean rewards as returned by ``Battle.fight``
    """
    fghts = get_fighters_from_list(a_fighter)
    if not a_force:
        # get all fighters that are valid for the given env
        fghts = [a for a in fghts if a.valid_env(a_env)]
    b = Battle(menv.get_env(a_env), fghts)
    res = b.fight(
        n_episodes=a_n_episodes, n_step=a_n_steps, render=render, n_jobs=n_jobs
    )
    print("Env:", a_env)
    print("Fighters:", [a.get_name() for a in fghts])
    print("Rewards:", res)
    return res
def main(argv=None):
    """Parse command-line arguments and run a battle.

    :param argv: list of arguments to parse; ``None`` (the default) lets
        argparse read ``sys.argv[1:]``
    """
    parser = argparse.ArgumentParser(description="Run a battle between agents.")
    parser.add_argument(
        "--n_episodes", type=int, default=10000, help="number of episodes to play"
    )
    parser.add_argument(
        "--n_step",
        type=int,
        default=None,
        help="number of steps to collect instead of episodes",
    )
    parser.add_argument(
        "--render",
        type=float,
        default=None,
        help="sleep time in seconds between rendered frames (default: no rendering)",
    )
    parser.add_argument(
        "--n_jobs",
        type=int,
        default=None,
        help="number of worker processes (default: CPU count)",
    )
    parser.add_argument(
        "--env", type=str, default="tictactoe_v3", help="environment name"
    )
    parser.add_argument(
        "--fighter",
        type=str,
        nargs="+",
        default=[],
        help="names of the fighters taking part",
    )
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        help="force non checked constellations",
        default=False,
    )
    args = parser.parse_args(argv)
    battle(
        a_fighter=args.fighter,
        a_env=args.env,
        a_n_episodes=args.n_episodes,
        a_n_steps=args.n_step,
        render=args.render,
        n_jobs=args.n_jobs,
        a_force=args.force,
    )
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()