import numpy as np
from scipy import signal
import matplotlib.pyplot as plt

import time
import datetime
from datetime import datetime as dtime

# Startup banner: a dashed rule, the program title, and another dashed rule.
print(60 * "-", " Cellular automaton modelling Conus", 60 * "-", sep="\n")

# Execution of the Cellular automaton

class Tools:
    """Small arithmetic helpers used by the cellular automaton.

    Both helpers are static, so they can be called either on the class
    (``Tools.delta(i, j)``) or on an instance."""

    @staticmethod
    def delta(i, j):
        """Kronecker symbol.
        Inputs:
        - i: Int - First input
        - j: Int - Second input
        Returns 1.0 if i == j and 0.0 otherwise, as a numpy float64."""
        # A direct comparison is exact for every integer, including negative
        # ones (an index-based unit impulse wraps negative indices around), and
        # it avoids allocating a temporary array on every call.
        return np.float64(i == j)

    @staticmethod
    def modulo(k, N):
        """Gives the representative of k in the integers modulo N.
        Inputs:
        - k: Int - Integer whose congruence class is computed
        - N: Int - Modulus (must be non-zero)
        Returns k modulo N; for N > 0 the result lies in [0, N), so
        negative k wrap around (e.g. modulo(-1, N) == N - 1)."""
        return k % N

class Conus_Modelling:
    """Cellular-automaton models used to draw Conus-like shell patterns.

    Every grid is an int64 array of shape (Ny, Nx): row 0 is the initial
    state and row i + 1 is computed from row i, with periodic (wrap-around)
    boundaries along the width."""

    @staticmethod
    def _check_size(Nx, Ny):
        """Raise ValueError unless both grid dimensions are at least 1."""
        if Nx < 1 or Ny < 1:
            raise ValueError(
                "Nx and Ny must be positive integers, got Nx={}, Ny={}".format(Nx, Ny))

    @staticmethod
    def _print_progress(done, total, start_time):
        """Print an in-place progress line with an estimated finishing time.
        Inputs:
        - done: Int - Number of cells already computed (> 0).
        - total: Int - Total number of cells to compute (> 0).
        - start_time: Float - Value of time.time() when the computation began."""
        elapsed = time.time() - start_time
        # Linear extrapolation of the elapsed time to the whole workload.
        end_time = start_time + elapsed * total / max(done, 1)
        end_str = datetime.datetime.fromtimestamp(int(round(end_time))).strftime('%Y-%m-%d %H:%M:%S')
        percent = 100.0 * done / total
        # The carriage return goes first so each update overwrites the previous one.
        print("\r Loading : {:6.2f} %  - Estimated time for ending : {} ".format(percent, end_str),
              end="", flush=True)

    @staticmethod
    def grid(Nx, Ny):
        """Creates a grid modelling the triangles of the conus.
        Inputs:
        Nx: Int - Width of the grid (>= 1).
        Ny: Int - Length of the grid (>= 1).
        Row 0 is random (each cell is 1 with probability 1/2); a cell of the
        next row is 1 exactly when its left and right neighbours (with
        wrap-around) both equal the cell itself, otherwise 0.
        Returns an int64 array of shape (Ny, Nx) which contains 0 and 1
        as coefficients.
        Raises ValueError if Nx or Ny is smaller than 1."""
        Conus_Modelling._check_size(Nx, Ny)
        A = np.zeros((Ny, Nx), dtype=np.int64)
        A[0, :] = np.random.binomial(n=1, p=0.5, size=(Nx,))

        total = Nx * (Ny - 1)  # Number of cells which have to be computed
        start_time = time.time()
        for i in range(Ny - 1):
            row = A[i]
            left = np.roll(row, 1)    # left[j] == row[(j - 1) % Nx]
            right = np.roll(row, -1)  # right[j] == row[(j + 1) % Nx]
            A[i + 1] = (left == row) & (right == row)
            Conus_Modelling._print_progress((i + 1) * Nx, total, start_time)
        if total:
            print()  # terminate the in-place progress line
        return A

    @staticmethod
    def grid_rule_90(Nx, Ny):
        """Creates a grid with rule 90 of elementary cellular automaton.
        Inputs:
        Nx: Int - Width of the grid (>= 1).
        Ny: Int - Length of the grid (>= 1).
        Row 0 contains a single 1 at column Nx // 2; every following cell is
        the XOR of its left and right neighbours (with wrap-around).
        Returns an int64 array of shape (Ny, Nx) which contains 0 and 1
        as coefficients.
        Raises ValueError if Nx or Ny is smaller than 1."""
        Conus_Modelling._check_size(Nx, Ny)
        A = np.zeros((Ny, Nx), dtype=np.int64)
        A[0, Nx // 2] = 1

        total = Nx * (Ny - 1)  # Number of cells which have to be computed
        start_time = time.time()
        for i in range(Ny - 1):
            row = A[i]
            # Rule 90: new cell = left XOR right (the centre cell is ignored).
            A[i + 1] = np.roll(row, 1) ^ np.roll(row, -1)
            Conus_Modelling._print_progress((i + 1) * Nx, total, start_time)
        if total:
            print()  # terminate the in-place progress line
        return A

    @staticmethod
    def Plot(Nx, Ny, type="Conus", save=False):
        """Creates and plots a grid modelling a Conus.
        Inputs:
        Nx: Int - Width of the grid.
        Ny: Int - Length of the grid.
        type: Str - Type of modelling, "Conus" or "Rule_90". Default: "Conus".
        save: Boolean - Saves the figure as
              "<type>_Nx_<Nx>_Ny_<Ny>.pdf" or not. Default: False.
              The strings "True"/"False" are also accepted.
        Raises ValueError for an unknown type."""
        builders = {
            "Conus": Conus_Modelling.grid,
            "Rule_90": Conus_Modelling.grid_rule_90,
        }
        if type not in builders:
            raise ValueError('Unknown type {!r}; expected "Conus" or "Rule_90"'.format(type))
        if isinstance(save, str):
            # The former default was the string "False"; keep accepting strings.
            save = save.strip().lower() == "true"

        A = builders[type](Nx, Ny)

        plt.figure()
        plt.imshow(A, cmap="gray")
        if save:
            plt.savefig(type + "_Nx_" + str(Nx) + "_Ny_" + str(Ny) + ".pdf", dpi=1000)
        plt.show()


