import numpy as np
import neat
import os

class Field:
    """A 3x3 grid of goals, each holding up to three balls stacked bottom to top.

    Attributes:
        state: 9x3 integer array. ``state[g][s]`` is the ball in slot ``s`` of
            goal ``g`` (0 = no ball, 1 = red ball, 2 = blue ball); slot 0 is
            the bottom slot.
        owned: length-9 integer array. ``owned[g]`` is the color of the ball
            most recently scored in goal ``g``, cleared back to 0 once the
            goal has been emptied by descoring.
    """

    def __init__(self):
        # Empty placeholders keep the attribute declarations inside __init__
        # (IDE-friendly); reset() replaces them with the real arrays right away.
        self.state = np.empty(0)
        self.owned = np.empty(0)

        self.reset()

    def reset(self):
        """Clear the field: every slot empty and no goal owned."""
        self.state = np.zeros((9, 3), dtype=int)  # 9 goals x 3 slots, all empty
        self.owned = np.zeros(9, dtype=int)  # no goal has an owner yet

    def score_ball(self, goal, color):
        """Drop a ball of ``color`` into ``goal``; a full goal is left untouched."""
        slots = self.state[goal]  # a view, so writes land in self.state
        free = np.flatnonzero(slots == 0)  # indices of the empty slots, lowest first

        if free.size == 0:  # the goal is already full
            return

        slots[free[0]] = color  # the ball settles in the lowest empty slot
        self.owned[goal] = color  # the scoring color now owns the goal

    def descore_ball(self, goal):
        """Remove the bottom ball of ``goal``, letting the balls above drop one slot."""
        slots = self.state[goal]  # a view, so writes land in self.state
        slots[:] = (slots[1], slots[2], 0)  # middle -> bottom, top -> middle, top emptied

        if not slots.any():  # nothing left in the goal
            self.owned[goal] = 0  # an empty goal has no owner

    def get_score(self):
        """Return the current score as ``[red_score, blue_score]``.

        A color earns 6 points for every completed row (three goals in a line,
        all owned by that color) plus 1 point for each of its balls on the field.
        """
        lines = (
            (0, 1, 2), (3, 4, 5), (6, 7, 8),  # horizontal rows
            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # vertical columns
            (0, 4, 8), (2, 4, 6),  # diagonals
        )
        totals = [0, 0]  # [red, blue]

        for first, second, third in lines:
            holder = self.owned[first]
            if holder != 0 and holder == self.owned[second] and holder == self.owned[third]:
                totals[holder - 1] += 6  # color 1 -> index 0 (red), color 2 -> index 1 (blue)

        totals[0] += int(np.count_nonzero(self.state == 1))  # add the number of red balls to red's score
        totals[1] += int(np.count_nonzero(self.state == 2))  # add the number of blue balls to blue's score

        return totals

class Player:
    """A competitor whose moves are chosen by a NEAT-generated network.

    Attributes:
        name: identifier for the player.
        net: network object exposing ``activate(inputs)``.
        wins: accumulated win points, used for fitness evaluation.
        scores: list of the scores this player achieved, also kept for fitness evaluation.
    """

    def __init__(self, name, net):
        self.name, self.net = name, net
        self.wins = 0
        self.scores = []

    def move(self, field, color):
        """Ask the network for an action and apply it to ``field``.

        The network input is the color descriptor followed by the flattened
        field state. The index of the largest output selects the action:
        indices 0-8 score a ball of ``color`` in that goal, higher indices
        descore goal ``index - 9``.
        """
        descriptor = 2 * (color - 1.5)  # -1 for red (color 1), 1 for blue (color 2)
        inputs = (descriptor, *field.state.flatten())
        output = self.net.activate(inputs)

        strongest = max(output)
        choice = output.index(strongest)  # first index holding the highest activation

        if choice > 8:  # the upper outputs are the descore actions
            field.descore_ball(choice - 9)
        else:  # outputs 0-8 are the score actions
            field.score_ball(choice, color)

def play_game(field, red_player, blue_player):
    """Play one game on ``field``: ten turns, red moving first and blue second in each.

    The field is modified in place by the players' moves; it is not reset here.
    """
    turn_order = ((red_player, 1), (blue_player, 2))  # 1 signifies red, 2 signifies blue

    for _ in range(10):
        for player, color in turn_order:
            player.move(field, color)

def play_tournament(players, rounds=1):
    """Play a round-robin tournament in which every ordered pair of players meets.

    In each round every player faces every other player twice: once as red
    (moving first) and once as blue. A win is worth 2 win points and a draw
    1 point to each side; each player's own score is appended to its
    ``scores`` list after every game.

    Args:
        players: sequence of players; their ``wins`` and ``scores`` are updated in place.
        rounds: how many times the full round-robin is repeated.

    Returns:
        A ``(players, high_score)`` tuple, where ``high_score`` is the largest
        single-side score seen in any game (0 if no game was played).
    """
    arena = Field()  # one field is reused for every game
    best = 0

    for _ in range(rounds):
        for red_index, red in enumerate(players):
            for blue_index, blue in enumerate(players):
                if red_index == blue_index:  # a player never plays itself
                    continue

                arena.reset()  # start every game from an empty field
                play_game(arena, red, blue)

                red_points, blue_points = arena.get_score()

                if red_points > blue_points:  # red won
                    red.wins += 2
                elif blue_points > red_points:  # blue won
                    blue.wins += 2
                else:  # draw: both sides get one point
                    red.wins += 1
                    blue.wins += 1

                red.scores.append(red_points)
                blue.scores.append(blue_points)

                best = max(best, red_points, blue_points)  # track the highest score of the tournament

    return players, best

def eval_genomes(genomes, config):
    """Fitness function for NEAT: score every genome by its tournament win points.

    Builds one player per genome, runs a single-round tournament between them,
    stores each player's win points in its genome's ``fitness`` and prints the
    tournament's high score and the best win-point total.

    Args:
        genomes: iterable of ``(genome_id, genome)`` pairs.
        config: the NEAT configuration used to build each genome's network.
    """
    players = []
    ge = []  # genomes, kept in the same order as `players`

    for genome_id, genome in genomes:
        genome.fitness = 0  # start from zero before the tournament assigns the real value

        network = neat.nn.FeedForwardNetwork.create(genome, config)
        players.append(Player(genome_id, network))
        ge.append(genome)

    players, high_score = play_tournament(players)

    for genome, player in zip(ge, players):
        genome.fitness = player.wins

    print("High score:", high_score)

    max_wins = max([0] + [player.wins for player in players])  # never reported below zero

    print('Max win points:', max_wins)

def run(config_file):
    """Evolve a population with NEAT and return the result of the run.

    Args:
        config_file: path to the NEAT configuration file.

    Returns:
        Whatever ``neat.Population.run`` returns (the winning genome), so that
        callers can inspect or save it. Previously the value was computed and
        then discarded.
    """
    config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                         neat.DefaultSpeciesSet, neat.DefaultStagnation,
                         config_file)

    population = neat.Population(config)

    # Report progress on stdout and collect per-generation statistics.
    population.add_reporter(neat.StdOutReporter(True))
    population.add_reporter(neat.StatisticsReporter())

    return population.run(eval_genomes, 50)  # run for up to 50 generations

if __name__ == '__main__':
    # Resolve the configuration file relative to this script's own directory
    # rather than the current working directory, then start the run.
    run(os.path.join(os.path.dirname(__file__), 'neat_config.txt'))

[NEAT]
fitness_criterion     = max
fitness_threshold     = 10000
pop_size              = 20
reset_on_extinction   = False

[DefaultGenome]
# node activation options
activation_default      = tanh
activation_mutate_rate  = 0.0
activation_options      = tanh

# node aggregation options
aggregation_default     = sum
aggregation_mutate_rate = 0.0
aggregation_options     = sum

# node bias options
bias_init_mean          = 0.0
bias_init_stdev         = 1.0
bias_max_value          = 30.0
bias_min_value          = -30.0
bias_mutate_power       = 0.5
bias_mutate_rate        = 0.7
bias_replace_rate       = 0.1

# genome compatibility options
compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient   = 0.5

# connection add/remove rates
conn_add_prob           = 0.5
conn_delete_prob        = 0.5

# connection enable options
enabled_default         = True
enabled_mutate_rate     = 0.01

feed_forward            = True
initial_connection      = full

# node add/remove rates
node_add_prob           = 0.2
node_delete_prob        = 0.2

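# network parameters
# num_inputs  must match Player.move: 1 color descriptor + 9 goals x 3 slots = 28
# num_outputs must match Player.move: 9 score actions + 9 descore actions = 18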
num_hidden              = 0
num_inputs              = 28
num_outputs             = 18

# node response options
response_init_mean      = 1.0
response_init_stdev     = 0.0
response_max_value      = 30.0
response_min_value      = -30.0
response_mutate_power   = 0.0
response_mutate_rate    = 0.0
response_replace_rate   = 0.0

# connection weight options
weight_init_mean        = 0.0
weight_init_stdev       = 1.0
weight_max_value        = 30
weight_min_value        = -30
weight_mutate_power     = 0.5
weight_mutate_rate      = 0.8
weight_replace_rate     = 0.1

[DefaultSpeciesSet]
compatibility_threshold = 3.0

[DefaultStagnation]
species_fitness_func = max
max_stagnation       = 20
species_elitism      = 2

[DefaultReproduction]
elitism            = 2
survival_threshold = 0.2