#!/usr/bin/env python

import random, math
import var, objs, agents, objtext
import time

# Algorithm for processing AI agent DNA:
#   Build the object list in random order
#   For each gene in the DNA
#     For each object in the object list
#       If the object's taxonomy includes the gene's type and the object
#       satisfies all bases (conditions) in the gene
#         Perform the gene's actions, one per tick, until all are done

# AI data structure constants

# Object categories a gene can target; Gene.test() requires the chosen
# entry to appear in the object's taxonomy.
types = ['none', 'significant', 'hard', 'damage', 'ship', 'sun',
         'fire', 'spike', 'asteroid', 'shield', 'bullet', 'powerup']

# Action combinations, one character per control ('.' = control unused):
#   t=thrust, b=backthrust, l=left, r=right, f=fire
actions = ['...', '..f', '.l.', '.lf', '.r.', '.rf',
           't..', 't.f', 'tl.', 'tlf', 'tr.', 'trf',
           'b..', 'b.f', 'bl.', 'blf', 'br.', 'brf']

# What a Base compares (evaluated in Base.test):
comparisons = [
        'none',       # always true
        'rand_num',   # <random int 0..39> operator value
        'dist_dist',  # <distance at future1> operator <distance at future2 + value>
        'dist_num',   # <distance at future1> operator value
        'dir_dir',    # <direction at future1> operator <direction at future2 + value>
        'dir_num']    # <direction at future1> operator value

# A Base's future1/future2 index into this table to get how many frames
# ahead positions are extrapolated for distance and direction.
# Each index table below has 40 entries; Base.random() draws indices with
# int(half_gauss(0, 40)), so small indices are the most likely.
futures = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
           10,11,12,13,14,15,16,17,18,19,
           20,22,24,26,28,30,32,34,36,38,
           40,45,50,55,60,65,70,75,80,85]

# Depending on the comparison, a Base's value is either a raw number
# ('rand_num') or an index into distances (dist_*) or directions (dir_*).
distances = [ 0,  10, 12, 14, 16,
              18, 20, 22, 24, 26,
              28, 30, 35, 40, 45,
              50, 55, 60, 65, 70,
              75, 80, 90,100,110,
             120,140,160,180,200,
             220,250,300,350,400,
             450,500,600,700,800]

# Index -> direction offset in units of pi/20 radians: 0..20, then -19..-1.
# list() keeps this a real list under both Python 2 and Python 3.
directions = list(range(21)) + list(range(-19, 0))

# '==' is an approximate match: for distances the two values must be within
# 20% of their mean, for directions within pi/20 radians; 'rand_num' uses
# exact equality.
operators = ['==', '>', '<']

# Upper half of gauss distribution
def half_gauss(start, end):
    """Sample the upper half of a normal distribution centred on *start*.

    The underlying normal has mean *start* and standard deviation
    (end - start) / 3.  Draws falling outside [start, end] are rejected and
    redrawn, so the result always lies in that closed interval.
    """
    sigma = (end - start) / 3.0
    while True:
        sample = random.gauss(start, sigma)
        if start <= sample <= end:
            return sample

# Temporary mass objects for performing calculations
class Temp(objs.Mass):
    """Scratch Mass used by Base.test to hold projected positions.

    Only x, y and dir are initialised here; objs.Mass.__init__ is not
    called.  NOTE(review): confirm objs.Mass.distance / rel_direction need
    no other instance state.
    """

    def __init__(self):
        # Set the attributes on the instance (plain locals would be discarded).
        self.x = 0
        self.y = 0
        self.dir = 0


t1 = Temp()  # projected ship position
t2 = Temp()  # projected object position

# Each base represents a condition with 5 components
class Base:
    """A single condition ("base") of a gene.

    comparison -- entry of ``comparisons`` selecting what is compared
    future1    -- index into ``futures``: frames ahead for the left-hand side
    future2    -- index into ``futures`` for the right-hand side
                  (only used by 'dist_dist' and 'dir_dir')
    value      -- index into ``distances`` / ``directions``, or the raw
                  threshold for 'rand_num'
    operator   -- '<', '>' or '==' (approximate equality, see test())
    """

    def __init__(self, c='none', f1=0, f2=0, v=0, op="=="):
        self.comparison = c
        self.future1 = f1
        self.future2 = f2
        self.value = v
        self.operator = op

    def random(self):
        """Overwrite every component with a randomly chosen value."""
        self.comparison = random.choice(comparisons)
        self.future1 = int(half_gauss(0, 40))
        self.future2 = int(half_gauss(0, 40))
        self.value = int(half_gauss(0, 40))
        self.operator = random.choice(operators)

    # -- helpers for test() -------------------------------------------------

    @staticmethod
    def _project(ship, object, frames, right, bottom):
        """Place t1/t2 where ship/object will be after *frames* ticks.

        Positions are extrapolated linearly from the current velocity and
        wrapped to the arena size (right x bottom).
        """
        t1.x = (ship.x + ship.vx * frames) % right
        t1.y = (ship.y + ship.vy * frames) % bottom
        t2.x = (object.x + object.vx * frames) % right
        t2.y = (object.y + object.vy * frames) % bottom

    @staticmethod
    def _dist_close(dist1, dist2):
        """True when two distances differ by less than 20% of their mean.

        Two zero distances (e.g. coinciding positions compared against
        distances[0] == 0) count as equal instead of dividing by zero.
        """
        mean = (dist1 + dist2) / 2
        if mean == 0:
            return True
        return abs((dist1 - dist2) / mean) < 0.2

    @staticmethod
    def _dir_close(dir1, dir2):
        """True when two directions (radians) differ by less than pi/20."""
        return abs(dir1 - dir2) < math.pi / 20

    def _compare(self, lhs, rhs, close):
        """Apply self.operator to lhs/rhs; '==' is delegated to *close*.

        Returns 1 or 0; an unknown operator never matches.
        """
        if self.operator == "<":
            return 1 if lhs < rhs else 0
        if self.operator == ">":
            return 1 if lhs > rhs else 0
        if self.operator == "==":
            return 1 if close(lhs, rhs) else 0
        return 0

    def test(self, ship, object):
        """Return 1 if this condition holds for *ship* against *object*, else 0.

        Distances and directions are measured, via objs.Mass.distance and
        objs.Mass.rel_direction, between the positions the two masses will
        have futures[future1] (and futures[future2]) frames from now.
        """
        right = var.arena.right
        bottom = var.arena.bottom
        f1 = futures[self.future1]
        f2 = futures[self.future2]
        kind = self.comparison

        if kind == 'none':
            return 1

        if kind == 'rand_num':
            rand = random.randint(0, 39)
            return self._compare(rand, self.value, lambda a, b: a == b)

        if kind == 'dist_dist':
            self._project(ship, object, f1, right, bottom)
            dist1 = float(objs.Mass.distance(t1, t2))
            self._project(ship, object, f2, right, bottom)
            dist2 = float(objs.Mass.distance(t1, t2) + distances[self.value])
            return self._compare(dist1, dist2, self._dist_close)

        if kind == 'dist_num':
            self._project(ship, object, f1, right, bottom)
            dist1 = float(objs.Mass.distance(t1, t2))
            dist2 = float(distances[self.value])
            return self._compare(dist1, dist2, self._dist_close)

        if kind == 'dir_dir':
            self._project(ship, object, f1, right, bottom)
            t1.dir = ship.dir
            dir1 = float(objs.Mass.rel_direction(t1, t2))
            self._project(ship, object, f2, right, bottom)
            t1.dir = ship.dir
            dir2 = float(objs.Mass.rel_direction(t1, t2))
            dir2 += directions[self.value] * (math.pi / 20.0)
            # Wrap the offset angle back into [-pi, pi] (a single correction;
            # presumably rel_direction already returns values in that range).
            if dir2 > math.pi:
                dir2 -= 2 * math.pi
            elif dir2 < -math.pi:
                dir2 += 2 * math.pi
            return self._compare(dir1, dir2, self._dir_close)

        if kind == 'dir_num':
            self._project(ship, object, f1, right, bottom)
            t1.dir = ship.dir
            dir1 = float(objs.Mass.rel_direction(t1, t2))
            dir2 = directions[self.value] * (math.pi / 20.0)
            return self._compare(dir1, dir2, self._dir_close)

        return 0

    def mutate(self):
        """Independently re-roll each component with some probability.

        NOTE(review): each component reuses one of the var.gene_action_*
        rates (add/copy/remove/swap/change) as its probability; there are no
        base-specific rates here -- confirm this reuse is intended.
        """
        # Change the comparison
        if random.random() < var.gene_action_add_rate:
            self.comparison = random.choice(comparisons)
        # Change future1
        if random.random() < var.gene_action_copy_rate:
            self.future1 = int(half_gauss(0, 40))
        # Change future2
        if random.random() < var.gene_action_remove_rate:
            self.future2 = int(half_gauss(0, 40))
        # Change value
        if random.random() < var.gene_action_swap_rate:
            self.value = int(half_gauss(0, 40))
        # Change operator
        if random.random() < var.gene_action_change_rate:
            self.operator = random.choice(operators)

    def copy(self):
        """Return an independent Base with the same five components."""
        return Base(c=self.comparison,
                    f1=self.future1,
                    f2=self.future2,
                    v=self.value,
                    op=self.operator)

    def __repr__(self):
        # "comparison future1 future2 value operator " (note trailing space)
        return "%s %s %s %s %s " % (self.comparison, self.future1,
                                    self.future2, self.value, self.operator)

# Each gene represents a stimulus->response entry
class Gene:
    """A stimulus->response rule.

    type   -- entry of ``types``; the gene only applies to objects whose
              taxonomy contains it
    base   -- list of Base conditions, all of which must hold
    action -- list of entries from ``actions`` to perform
    """

    def __init__(self):
        self.type = 'significant'
        self.base = []    # Base conditions (random() creates 1..var.base_max)
        self.action = []  # action strings

    def random(self):
        """Pick a random type and append random bases and 1-10 random actions."""
        self.type = random.choice(types)
        for _ in range(random.randint(1, var.base_max)):
            base = Base()
            base.random()
            self.base.append(base)
        for _ in range(random.randint(1, 10)):
            self.action.append(random.choice(actions))

    def test(self, ship, object):
        """Return 1 if the gene applies to *object* for *ship*, else 0."""
        if self.type not in object.taxonomy:
            return 0
        for base in self.base:
            if not base.test(ship, object):
                return 0
        return 1

    # -- list-editing helpers for mutate() ----------------------------------

    @staticmethod
    def _insert_at_random(items, item):
        """Insert *item* at a uniformly random position 0..len(items)."""
        location = random.randint(0, len(items))
        items.insert(location, item)

    @staticmethod
    def _delete_at_random(items):
        """Delete the element at a uniformly random position of *items*.

        Deletes by position: list.remove() would delete the first *equal*
        element, which for duplicated action strings is often another one.
        """
        del items[random.randint(0, len(items) - 1)]

    @staticmethod
    def _swap_at_random(items):
        """Swap two uniformly chosen (possibly identical) positions."""
        loc1 = random.randint(0, len(items) - 1)
        loc2 = random.randint(0, len(items) - 1)
        items[loc1], items[loc2] = items[loc2], items[loc1]

    def mutate(self):
        """Apply random structural changes, each gated by a var.gene_* rate."""
        # Change the object type
        if random.random() < var.gene_type_rate:
            self.type = random.choice(types)

        # Bases: add, copy, mutate, remove, swap
        if random.random() < var.gene_base_add_rate:
            base = Base()
            base.random()
            self._insert_at_random(self.base, base)
        if self.base:
            if random.random() < var.gene_base_copy_rate:
                source = random.randint(0, len(self.base) - 1)
                self._insert_at_random(self.base, self.base[source].copy())
            if random.random() < var.gene_base_mutate_rate:
                random.choice(self.base).mutate()
        # Removal/swap only with 2+ bases, so a gene never loses its last one.
        if len(self.base) > 1:
            if random.random() < var.gene_base_remove_rate:
                self._delete_at_random(self.base)
            if random.random() < var.gene_base_swap_rate:
                self._swap_at_random(self.base)

        # Actions: add, copy, change, remove, swap
        if random.random() < var.gene_action_add_rate:
            self._insert_at_random(self.action, random.choice(actions))
        if self.action:
            if random.random() < var.gene_action_copy_rate:
                source = random.randint(0, len(self.action) - 1)
                self._insert_at_random(self.action, self.action[source])
            if random.random() < var.gene_action_change_rate:
                location = random.randint(0, len(self.action) - 1)
                self.action[location] = random.choice(actions)
        if len(self.action) > 1:
            if random.random() < var.gene_action_remove_rate:
                self._delete_at_random(self.action)
            if random.random() < var.gene_action_swap_rate:
                self._swap_at_random(self.action)

    def copy(self):
        """Return a deep copy (bases are copied, action list is new)."""
        copy = Gene()
        copy.type = self.type
        copy.base = [base.copy() for base in self.base]
        copy.action = list(self.action)
        return copy

    def __repr__(self):
        lines = ["  <", "    type: " + self.type, "    <"]
        lines.extend("      " + repr(base) for base in self.base)
        lines.append("    >")
        lines.append("    action: " + "".join(action + " " for action in self.action))
        return "\n".join(lines) + "\n  >\n"

# A do-nothing gene: one default Base (comparison 'none', so always true)
# paired with the idle action '...'.
null_gene = Gene()
null_gene.base.append(Base())
null_gene.action = ['...']

# DNA is a list of genes in priority order
class DNA:
    """An AI genome: a priority-ordered list of Gene plus game statistics.

    parents -- optional list of DNA.  With one parent the genes are
               deep-copied; with two, the child gets a random prefix of
               parent 0's genes followed by a random suffix of parent 1's.
               Any other number of parents yields an empty genome.
    """

    def __init__(self, parents=None):
        # None instead of a shared mutable [] default.
        if parents is None:
            parents = []
        self.gene = []
        self.rating = 1.0
        self.time = 0    # Amount of time in game
        self.damage = 0  # Amount of damage inflicted in game
        self.plays = 0   # Number of opportunities to be in the game

        if len(parents) == 1:
            # Essentially just a copy of another DNA
            self.gene = [gene.copy() for gene in parents[0].gene]
        elif len(parents) == 2:
            # Cross-breed: prefix of the first parent + suffix of the second.
            first, second = parents
            p0_end = random.randint(0, len(first.gene))
            p1_start = random.randint(0, len(second.gene))
            self.gene = ([gene.copy() for gene in first.gene[:p0_end]] +
                         [gene.copy() for gene in second.gene[p1_start:]])

        # Penalize long DNA: x0.95 for each started block of 100 genes past 500.
        # NOTE(review): runga() computes its own ratings from time/damage, so
        # this value is not read anywhere in this file.
        for _ in range(500, len(self.gene), 100):
            self.rating *= 0.95

    def random(self):
        """Append 1-10 entirely random genes."""
        for _ in range(random.randint(1, 10)):
            gene = Gene()
            gene.random()
            self.gene.append(gene)

    def mutate(self):
        """Apply random structural changes, each gated by a var.dna_* rate."""
        # Add a random gene to random location
        if random.random() < var.dna_add_rate:
            gene = Gene()
            gene.random()
            location = random.randint(0, len(self.gene))
            self.gene.insert(location, gene)
        if self.gene:
            # Copy a random gene to random location
            if random.random() < var.dna_copy_rate:
                location = random.randint(0, len(self.gene) - 1)
                gene = self.gene[location].copy()
                location = random.randint(0, len(self.gene))
                self.gene.insert(location, gene)
            # Mutate a gene
            if random.random() < var.dna_mutate_rate:
                random.choice(self.gene).mutate()
        if len(self.gene) > 1:
            # Remove the gene at a random position (by index, not equality)
            if random.random() < var.dna_remove_rate:
                del self.gene[random.randint(0, len(self.gene) - 1)]
            # Swap two genes at random
            if random.random() < var.dna_swap_rate:
                genes = self.gene
                loc1 = random.randint(0, len(genes) - 1)
                loc2 = random.randint(0, len(genes) - 1)
                genes[loc1], genes[loc2] = genes[loc2], genes[loc1]

    def copy(self):
        """Return a new DNA holding deep copies of this DNA's genes.

        Statistics (rating, time, damage, plays) start fresh in the copy.
        """
        duplicate = DNA()
        duplicate.gene = [gene.copy() for gene in self.gene]
        return duplicate

    def __repr__(self):
        return "<\n" + "".join(repr(gene) for gene in self.gene) + "\n>"

dna_pool = []  # current generation of DNA; rebound by load_dna_pool / runga


def _parse_dna_pool(lines):
    """Build a list of DNA from the significant lines of an AI file.

    *lines* must already be stripped, with blank and '#' lines removed.
    Layout (as written by save_dna_pool):
        <dna count>
        per DNA:  <gene count>
        per gene: <type>
                  <base count>
                  one "comparison,future1,future2,value,operator" line per base
                  one comma-separated action line
    Raises ValueError if the data ends early or a count is not an integer.
    """
    rows = iter(lines)
    pool = []
    try:
        for _ in range(int(next(rows))):
            dna = DNA()
            pool.append(dna)
            for _ in range(int(next(rows))):
                gene = Gene()
                dna.gene.append(gene)
                gene.type = next(rows)
                gene.base = []
                for _ in range(int(next(rows))):
                    parts = next(rows).split(',')
                    gene.base.append(Base(c=parts[0],
                                          f1=int(parts[1]),
                                          f2=int(parts[2]),
                                          v=int(parts[3]),
                                          op=parts[4]))
                gene.action = next(rows).split(',')
    except StopIteration:
        raise ValueError("AI file ended before all declared DNA were read")
    return pool


def load_dna_pool(file="default"):
    """Reset dna_pool from the AI file ai/<file>.

    If the file cannot be opened, the pool is instead filled with
    var.population_size random DNA.
    """
    global dna_pool

    # Reset the pool
    dna_pool = []

    path = "ai/" + file
    try:
        ai_file = open(var.get_resource(path))
    except (IOError, OSError):
        # Not found (or unreadable): start from a random population.
        for _ in range(var.population_size):
            dna = DNA()
            dna.random()
            dna_pool.append(dna)
        return

    with ai_file:
        lines = [line for line in (raw.strip() for raw in ai_file)
                 if line and not line.startswith("#")]

    dna_pool.extend(_parse_dna_pool(lines))

def save_dna_pool(file="new"):
    """Write dna_pool to ai/<file> in the layout load_dna_pool reads.

    The file is closed (and flushed) before returning, even on error.
    """
    path = "ai/" + file
    with open(var.get_resource(path), 'w') as ai_file:
        ai_file.write("%d\n" % len(dna_pool))
        for dna in dna_pool:
            ai_file.write("    %d\n" % len(dna.gene))
            for gene in dna.gene:
                ai_file.write("        %s\n" % gene.type)
                ai_file.write("        %d\n" % len(gene.base))
                for base in gene.base:
                    ai_file.write("            %s,%d,%d,%d,%s\n" %
                                  (base.comparison,
                                   base.future1,
                                   base.future2,
                                   base.value,
                                   base.operator))
                # Named action_line so the module-level `actions` table is not shadowed.
                action_line = ",".join(gene.action)
                ai_file.write("        %s\n\n" % action_line)
            ai_file.write("\n")

def rating_pick(ratings):
    """Return an index into *ratings*, chosen with probability proportional
    to its rating.

    Example: for ratings == [1.0, 2.0], index 1 is returned twice as often
    as index 0.  If the ratings sum to zero (or less), an index is chosen
    uniformly instead of always returning the last one.

    Raises ValueError if *ratings* is empty.
    """
    if not ratings:
        raise ValueError("rating_pick() needs at least one rating")
    total = sum(ratings)
    if total <= 0:
        return random.randrange(len(ratings))
    remaining = random.random() * total
    for index, rating in enumerate(ratings):
        remaining -= rating
        if remaining < 0.0:
            return index
    # Floating-point rounding can leave `remaining` >= 0: fall back to last.
    return len(ratings) - 1

generation_count = 0  # number of GA runs so far


def runga():
    """Replace dna_pool with a new generation bred from the current one.

    Fitness of each DNA is its time in game plus damage inflicted, each
    scaled so the best value in the pool is 1.0.  Parents are drawn with
    rating_pick() on those fitnesses; the breeding mode is drawn with
    rating_pick() on var.cross_pure_rate / cross_mutate_rate /
    single_mutate_rate.
    """
    global dna_pool, generation_count

    generation_count += 1
    # Each print has a single pre-formatted argument, so the output is the
    # same under the Python 2 print statement and the Python 3 function.
    print("Running GA:  %d" % generation_count)

    # The 0.00001 floor avoids dividing by zero when nobody scored at all.
    times = [member.time for member in dna_pool]
    damages = [member.damage for member in dna_pool]
    tscaler = 1.0 / max(times + [0.00001])
    dscaler = 1.0 / max(damages + [0.00001])
    ratings = [t * tscaler + d * dscaler for t, d in zip(times, damages)]
    print("ratings:  %s" % (ratings,))

    new_dna_pool = []
    for _ in range(var.population_size):
        # Modes of pro-creation:
        # 0. Pure cross-breed
        # 1. Cross-breed with mutation
        # 2. Single with mutation
        mode = rating_pick([var.cross_pure_rate,
                            var.cross_mutate_rate,
                            var.single_mutate_rate])
        if mode == 2:  # Single parent
            parents = [dna_pool[rating_pick(ratings)]]
        else:  # Cross-breed two parents
            parents = [dna_pool[rating_pick(ratings)],
                       dna_pool[rating_pick(ratings)]]
        dna = DNA(parents=parents)
        if mode in (1, 2):
            dna.mutate()
        new_dna_pool.append(dna)
    dna_pool = new_dna_pool


deaths = 0  # players removed since the last GA run


def ai_tick(players):
    """Advance the AI-training simulation by one frame.

    Removes every player whose ship is dead -- and, once the death count
    exceeds var.deaths_per_generation, every remaining player as well.
    When that threshold is exceeded, breeds a new generation via runga().
    If fewer than 2 players survive, tops the field back up to 2-4
    DNA-driven ships.  Finally lets every player act.
    """
    global deaths

    # `deaths` is re-checked per player, so once the threshold is crossed
    # the rest of the list is culled too.
    doomed = []
    survivors = 0
    for p in players:
        if not (p.ship.dead or deaths > var.deaths_per_generation):
            survivors += 1
            continue
        doomed.append(p)
        deaths += 1

    for p in doomed:
        ship = p.ship
        owner = objs.pend if ship in objs.pend else objs.low
        owner.remove(ship)
        players.remove(p)

    # Time for a new generation of DNA?
    if deaths > var.deaths_per_generation:
        objs.virtual.append(objtext.Text('Running GA'))
        deaths = 0
        runga()

    # Replenish when fewer than two players remain
    if survivors < 2:
        for _ in range(random.randint(2, 4) - survivors):
            num = random.randint(0, 3)
            ship = objs.Ship(num, 10, 10)
            ship.start(dir=random.randint(0, 359))
            ship.find_spot()
            dna = random.choice(dna_pool)
            dna.plays += 1
            agent = agents.DNAAgent(num, ship, dna=dna,
                                    objs=[objs.low, objs.high])
            players.append(agent)
            objs.low.append(ship)

    for p in players:
        p.do_action()


def main():
    """Run the standalone AI-training loop until QUIT or Escape.

    Loads the pool from ai/generation (or randomizes it), then repeatedly
    steps objects, draws the HUD and runs ai_tick().  On exit the pool is
    saved to ai/generation.<YYYYmmdd-HHMMSS>.  Pass '-window' on the
    command line to avoid full-screen mode.
    """
    import sys
    import pygame
    import gfx
    import hud

    # Initialize everything necessary
    pygame.init()
    var.clock = pygame.time.Clock()
    var.ai_train = 1
    full = 0 if '-window' in sys.argv else 1
    gfx.initialize((800, 600), full)
    pygame.display.set_caption('Spacewar')

    objs.load_game_resources()
    hud.load_game_resources()
    objtext.load_game_resources()
    load_dna_pool("generation")

    # Separate name so the `hud` module is not shadowed by the instance.
    ai_hud = hud.aiHUD()
    ai_hud.setwidth(100)

    players = []

    # Main game event loop
    while 1:
        for event in pygame.event.get():
            quit_requested = (event.type == pygame.QUIT or
                              (event.type == pygame.KEYDOWN and
                               event.key == pygame.K_ESCAPE))
            if quit_requested:
                # No ':' in the timestamp: it is a reserved character in
                # Windows file names.
                name = "generation." + time.strftime("%Y%m%d-%H%M%S")
                save_dna_pool(name)
                sys.exit(0)

        objs.runobjects(1.0)
        ai_hud.draw()
        ai_tick(players)
        gfx.update()

        # No var.clock.tick() here, so the loop runs uncapped (presumably to
        # train faster); call var.clock.tick(var.frames_per_sec) to cap it.


if __name__ == '__main__':
    main()