import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from matplotlib import animation
import itertools
import random

import warnings
import matplotlib as mpl
import time
import datetime
from datetime import datetime as dtime

# NOTE(review): this silences *every* warning for the whole process, which can
# also hide library deprecation notices; consider a narrower filter.
warnings.filterwarnings("ignore")

# Script banner: a 60-dash rule, the title, and another 60-dash rule.
print("\n".join(("-" * 60, " Maze generation and resolution algorithm", "-" * 60)))

class Generation:
    """Maze generation with randomized Kruskal's algorithm, plus plotting helpers.

    A maze of Nx x Ny cells is stored as an integer grid of shape
    (2*Ny+1, 2*Nx+1): squares at (odd row, odd column) are cells, squares
    between two cells are walls that may be opened, and the outer ring is the
    border. 0 means wall; a positive value means open. The entrance is the
    square (1, 0) and the exit is (2*Ny-1, 2*Nx).
    """

    @staticmethod
    def creation(Nx, Ny):
        """Generate an Nx x Ny maze with randomized Kruskal's algorithm and save it.

        Every cell starts with its own label 1..Nx*Ny. The internal walls are
        visited in random order; a wall is opened when the two cells it
        separates carry different labels, and the smaller label is then
        replaced by the larger one everywhere in the grid. Afterwards
        len(walls)//20 randomly chosen walls are opened as well, which may add
        loops.

        Inputs:
        - Nx: Int - Number of cells along the horizontal axis
        - Ny: Int - Number of cells along the vertical axis

        Saves "Maze_Nx=<Nx>_Ny=<Ny>.npy": a pickled 2-element object array
        holding the final grid (shape (2*Ny+1, 2*Nx+1)) and the list of grid
        snapshots taken after every iteration.
        """
        label_max = Nx * Ny
        G = np.zeros((2 * Ny + 1, 2 * Nx + 1), dtype=np.int64)
        G[1::2, 1::2] = np.arange(1, label_max + 1).reshape(Ny, Nx)
        # Entrance and exit openings carry the largest label, which is the
        # label every merged component ends up with.
        G[1, 0], G[-2, -1] = label_max, label_max
        Hist_G = []

        # Candidate walls stored as (column, row): walls between horizontal
        # neighbours (even column, odd row), then between vertical neighbours
        # (odd column, even row).
        IDX = list(itertools.product(range(2, 2 * Nx - 1, 2), range(1, 2 * Ny, 2))) \
            + list(itertools.product(range(1, 2 * Nx, 2), range(2, 2 * Ny - 1, 2)))
        IDX = random.sample(IDX, len(IDX))
        IDX_sparse = random.sample(IDX, len(IDX) // 20)

        for i, j in IDX:
            # count_nonzero counts each square once (np.size(np.where(...))
            # counted every square twice: one row index + one column index).
            print(" Cells not connected :  {}  \r".format(
                np.count_nonzero((G > 0) & (G < label_max))), end="  ")
            if i % 2 == 0:
                a, b = G[j, i - 1], G[j, i + 1]   # left / right cells
            else:
                a, b = G[j - 1, i], G[j + 1, i]   # upper / lower cells
            if a != b:
                max_value, min_value = max(a, b), min(a, b)
                G[G == min_value] = max_value     # merge the two components
                G[j, i] = max_value               # open the wall
            Hist_G.append(G.copy())

        # Extra openings (the wall may already be open).
        for i, j in IDX_sparse:
            G[j, i] = label_max
            Hist_G.append(G.copy())

        # Build the object array explicitly: np.save on the ragged tuple
        # (G, Hist_G) raises ValueError on NumPy >= 1.24.
        payload = np.empty(2, dtype=object)
        payload[0], payload[1] = G, Hist_G
        np.save("Maze_Nx=" + str(Nx) + "_Ny=" + str(Ny), payload, allow_pickle=True)

    @staticmethod
    def plot_creation(name, save=False):
        """Plot a maze saved by `creation`.

        Inputs:
        - name: Str - Name of the .npy file to load; element 0 is the grid
        - save: Boolean - Also save the figure as a PDF. Default: False
        """
        G = np.load(name, allow_pickle=True)[0]
        im = plt.imshow(G)
        im.set_cmap("nipy_spectral")
        im.axes.get_xaxis().set_visible(False)
        im.axes.get_yaxis().set_visible(False)
        if save:
            Ny, Nx = np.shape(G)[0] // 2, np.shape(G)[1] // 2
            plt.savefig("Maze_Nx=" + str(Nx) + "_Ny=" + str(Ny) + ".pdf", dpi=500)
        plt.show()

    @staticmethod
    def animation_creation(name, save=False):
        """Animate the snapshots recorded while `creation` built a maze.

        Inputs:
        - name: Str - Name of the .npy file to load; holds (grid, snapshots)
        - save: Boolean - Also save the animation as a GIF. Default: False
        """
        G, List_G = np.load(name, allow_pickle=True)

        N = len(List_G)
        deltat = 100  # Delay between two frames [ms] (FuncAnimation interval)

        fig, ax = plt.subplots()
        im = ax.imshow(List_G[0], cmap="nipy_spectral")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_title("Maze Building")

        def animate(n):
            im.set_array(List_G[n])
            return [im]

        anim = animation.FuncAnimation(fig, animate, frames=N, blit=True,
                                       interval=deltat, repeat=True)
        fig.tight_layout()
        if save:
            # Rows encode Ny and columns Nx (previously swapped in the name).
            Ny, Nx = np.shape(G)[0] // 2, np.shape(G)[1] // 2
            anim.save("Maze_building_Nx=" + str(Nx) + "_Ny=" + str(Ny) + ".gif", writer="pillow")
        plt.show()

class Solve:
    """Maze solving by "distance from exit" flood fill, plus plotting helpers.

    Works on grids of shape (2*Ny+1, 2*Nx+1) where 0 is a wall and a positive
    value is an open square, with the entrance at (1, 0) and the exit at
    (2*Ny-1, 2*Nx).
    """

    @staticmethod
    def maze_solver(name):
        """Solve a saved maze with the "distance from exit" method and save the result.

        Walls are set to -1 and open squares to 0. A breadth-first flood fill
        starting at the exit (value 1) then labels every reachable square with
        its distance to the exit. A runner walks from the entrance, always
        stepping to a neighbour one unit closer to the exit (checked in the
        order up, down, left, right), and marks its path with 2*M, M being the
        largest distance.

        Inputs:
        - name: Str - Name of the .npy file to load; element 0 is the maze grid

        Saves "Solved_Maze_Nx=<Nx>_Ny=<Ny>.npy": a pickled 2-element object
        array holding the solved grid and the list of snapshots taken after
        each flood-fill level and each runner step.

        Raises:
        - ValueError: the entrance cannot be reached from the exit
        - RuntimeError: the runner finds no neighbour closer to the exit
        """
        G = np.load(name, allow_pickle=True)[0]
        Hist_G = []
        Ny, Nx = np.shape(G)[0] // 2, np.shape(G)[1] // 2

        # Flood fill from the exit to evaluate the distance to the exit.
        idx_0, idx = np.where(G == 0), np.where(G > 0)
        G[idx_0], G[idx] = -1, 0
        G[-2, -1] = 1   # exit opening
        G[-2, -2] = 2   # cell next to the exit
        dist = 2
        while np.any(G == 0):
            rows, cols = np.where(G == dist)
            if rows.size == 0:
                # Empty frontier: the remaining 0-squares are unreachable from
                # the exit; stop instead of looping forever.
                break
            for i, j in zip(rows, cols):
                for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    if G[i + di, j + dj] == 0:
                        G[i + di, j + dj] = dist + 1
            dist += 1
            Hist_G.append(G.copy())

        # Escape from the maze, following strictly decreasing distances.
        if G[1, 0] <= 0:
            raise ValueError("The maze entrance is not connected to the exit")
        M = np.max(G)

        idx_runner = (1, 0)
        dist = G[idx_runner]
        G[idx_runner] = 2 * M
        Hist_G.append(G.copy())
        idx_runner = (1, 1)
        G[idx_runner] = 2 * M
        dist = dist - 1

        exit_square = (2 * Ny - 1, 2 * Nx)
        while idx_runner != exit_square:
            r, c = idx_runner
            # Adjacent BFS distances differ by exactly 1, so at most the first
            # matching neighbour is ever taken in one step.
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                if G[r + dr, c + dc] == dist - 1:
                    G[r + dr, c + dc] = 2 * M
                    idx_runner = (r + dr, c + dc)
                    break
            else:
                raise RuntimeError("Maze solver is stuck: no neighbour is closer to the exit")
            dist = dist - 1
            Hist_G.append(G.copy())

        # Build the object array explicitly: np.save on the ragged tuple
        # (G, Hist_G) raises ValueError on NumPy >= 1.24.
        payload = np.empty(2, dtype=object)
        payload[0], payload[1] = G, Hist_G
        np.save("Solved_Maze_Nx=" + str(Nx) + "_Ny=" + str(Ny), payload, allow_pickle=True)

    @staticmethod
    def plot_solver(name, save=False):
        """Plot a maze solved by `maze_solver`.

        Inputs:
        - name: Str - Name of the .npy file to load; element 0 is the solved grid
        - save: Boolean - Also save the figure as a PDF. Default: False
        """
        G = np.load(name, allow_pickle=True)[0]
        im = plt.imshow(G)
        List_Colors = [(0, "black"), (0.00001, "black"), (0.00001, "yellow"), (0.25, "red"),
                       (0.25, "red"), (0.55, "indigo"), (0.55, "green"), (1.0, "green")]
        Cmap = mpl.colors.LinearSegmentedColormap.from_list("", List_Colors)
        im.set_cmap(Cmap)
        im.axes.get_xaxis().set_visible(False)
        im.axes.get_yaxis().set_visible(False)
        if save:
            Ny, Nx = np.shape(G)[0] // 2, np.shape(G)[1] // 2
            plt.savefig("Solved_Maze_Nx=" + str(Nx) + "_Ny=" + str(Ny) + ".pdf", dpi=500)
        plt.show()

    @staticmethod
    def animation_solver(name, save=False):
        """Animate the snapshots recorded while `maze_solver` solved a maze.

        Inputs:
        - name: Str - Name of the .npy file to load; holds (grid, snapshots)
        - save: Boolean - Also save the animation as a GIF. Default: False
        """
        G, List_G = np.load(name, allow_pickle=True)

        N = len(List_G)
        M = np.max(G)

        # Remap every frame in place before display: walls -> -1, unvisited
        # squares -> a fixed base value, visited squares -> base plus a ramp
        # proportional to their distance, path squares -> M.
        base = -1 + int((M + 1) / 4)
        for frame in List_G:
            idx_m1 = np.where(frame == -1)
            idx_0 = np.where(frame == 0)
            idx = np.where((frame > 0) & (frame < M))
            idx_M = np.where(frame == M)
            # The right-hand side is evaluated completely before any
            # assignment, so the ramp uses the original distances.
            frame[idx_m1], frame[idx_0], frame[idx], frame[idx_M] = (
                -1,
                base,
                base + 2 * np.int64(((M + 1) / (2 * M)) * frame[idx]),
                M,
            )

        deltat = 20  # Delay between two frames [ms] (FuncAnimation interval)

        fig, ax = plt.subplots()
        List_Colors = [(0, "black"), (0.1, "black"), (0.1, "white"), (0.25, "white"), (0.25, "yellow"),
                       (0.5, "red"), (0.5, "red"), (0.75, "indigo"), (0.75, "green"), (1.0, "green")]
        Cmap = mpl.colors.LinearSegmentedColormap.from_list("", List_Colors)
        im = ax.imshow(List_G[0], cmap=Cmap, vmin=-1, vmax=M)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_title("Maze Solving")

        def animate(n):
            im.set_array(List_G[n])
            return [im]

        anim = animation.FuncAnimation(fig, animate, frames=N, blit=True,
                                       interval=deltat, repeat=True)
        # Lay out before saving so the GIF matches the on-screen figure.
        fig.tight_layout()
        if save:
            Ny, Nx = np.shape(G)[0] // 2, np.shape(G)[1] // 2
            anim.save("Maze_solving_Nx=" + str(Nx) + "_Ny=" + str(Ny) + ".gif", writer="pillow")
        plt.show()