import numpy as np
import skimage
import matplotlib.pyplot as plt
import matplotlib as mpl
from datetime import datetime

# Start-up banner: a rule, a blank line, the title, a rule and a blank line, one per output line.
print(150 * "_", " ", "  FRACTAL DIMENSION", 150 * "_", " ", sep="\n")

class Tools:
    """Class for tools and auxiliary functions."""

    # Short name -> long name; the long name is also the file stem of the saved fractal.
    _FRACTAL_NAMES = {
        "SC": "Sierpinski_Carpet",
        "RF": "Rivera_Fractal",
        "BM": "Brownian_Motion",
        "KS": "Koch_Snowflake",
    }

    @staticmethod
    def Line_Matrix(X, idx_1, idx_2):
        """Creates an approximation of a line between two coefficients of a 2D array X.

        Inputs:
        - X: 2D Array - Input array. It is modified in place and also returned.
          Arrays whose rank is not 2 are returned unchanged.
        - idx_1: List of length 2 - Starting coordinate [i, j].
        - idx_2: List of length 2 - Ending coordinate [i, j].

        Returns X where the coefficients of the approximated segment from X[idx_1]
        to X[idx_2] (both endpoints included) are set to 1. The segment advances
        one cell at a time along the axis with the larger extent; the coordinate
        on the other axis is interpolated and floored. A segment whose endpoints
        coincide sets that single coefficient.

        Raises:
        - TypeError if idx_1 or idx_2 is not a list.
        - RuntimeError if X is 2D and idx_1 or idx_2 does not have length 2.
        """
        if not isinstance(idx_1, list) or not isinstance(idx_2, list):
            raise TypeError("idx_1 and idx_2 must be lists.")

        if len(X.shape) != 2:
            # Only 2D arrays are rasterized.
            return X

        if len(idx_1) != 2 or len(idx_2) != 2:
            raise RuntimeError("idx_1 and idx_2 must be lists of length 2.")
        i_1, j_1 = idx_1
        i_2, j_2 = idx_2
        Delta_i, Delta_j = int(i_2 - i_1), int(j_2 - j_1)

        if Delta_i == 0 and Delta_j == 0:
            # Degenerate segment: it consists of the starting cell only.
            X[int(i_1), int(j_1)] = 1
            return X

        # Integer floor division keeps the interpolation exact. The float form
        # t / Delta * other can land just below an integer
        # (e.g. 1 / 49 * 49 == 0.9999999999999999) and be truncated one cell short.
        if abs(Delta_i) > abs(Delta_j):
            # More rows than columns: one cell per row, column interpolated.
            step = 1 if Delta_i > 0 else -1
            t = np.arange(0, Delta_i + step, step)
            IDX_i = np.asarray(i_1 + t, dtype=int)
            IDX_j = np.asarray(j_1 + (t * Delta_j) // Delta_i, dtype=int)
        else:
            # At least as many columns as rows: one cell per column, row interpolated.
            step = 1 if Delta_j > 0 else -1
            t = np.arange(0, Delta_j + step, step)
            IDX_i = np.asarray(i_1 + (t * Delta_i) // Delta_j, dtype=int)
            IDX_j = np.asarray(j_1 + t, dtype=int)

        X[IDX_i, IDX_j] = 1
        return X

    @staticmethod
    def divisors(n):
        """Gives the list of the divisors of an integer.

        Inputs:
        - n: Int - Input integer.

        Returns a list containing the positive divisors of n in increasing order
        (an empty list when n < 1)."""
        return [k for k in range(1, n + 1) if n % k == 0]

    @staticmethod
    def name(short_name):
        """Gives the name of the fractal with the short name.

        Inputs:
        - short_name: Str - Short name of the fractal.
            > "SC": Sierpinski Carpet.
            > "RF": Rivera Fractal.
            > "BM": Brownian Motion.
            > "KS": Koch Snowflake.

        Returns the long name (e.g. "Sierpinski_Carpet" for "SC").
        Raises ValueError for an unknown short name.
        """
        try:
            return Tools._FRACTAL_NAMES[short_name]
        except KeyError:
            raise ValueError(
                "Unknown fractal short name " + repr(short_name)
                + "; expected one of " + ", ".join(Tools._FRACTAL_NAMES) + "."
            ) from None

class Fractals:
    """Class for fractal generation.

    Each generator builds a 2D array of 0s and 1s (1 = cell on the fractal) and
    saves, with np.save, a dict with keys 'name', 'fractal', 'cells_size' (box
    sizes for box counting, each dividing both array dimensions) and 'date'.
    """

    # Directory of the saved fractal dictionaries (np.save appends ".npy").
    _SAVE_DIR = r"C:\Documents\Implementaitons\TempProjects\Fractals_Dictionnaries"

    @staticmethod
    def _save(file_stem, params):
        """Saves params with np.save in _SAVE_DIR under the given file stem."""
        np.save(Fractals._SAVE_DIR + "\\" + file_stem, params)

    @staticmethod
    def _substitution_3x3(pattern, n_iter=7):
        """Builds a 3x3 self-similar fractal by repeated substitution.

        Starting from a single 1-cell, each iteration replaces the array S by a
        3x3 arrangement made of copies of S where pattern has 1 and zero blocks
        where pattern has 0. The result has shape (3**n_iter, 3**n_iter)."""
        S = np.ones((1, 1))
        for _ in range(n_iter):
            Z = np.zeros_like(S)
            S = np.block([[S if keep else Z for keep in row] for row in pattern])
        return S

    @staticmethod
    def Sierpinski_Carpet():
        """Creates the Sierpinski carpet (7 iterations, 2187 x 2187) and saves it."""
        date_now = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
        S = Fractals._substitution_3x3(((1, 1, 1), (1, 0, 1), (1, 1, 1)))
        params = {'name': "Sierpinski Carpet", 'fractal': S, 'cells_size': [3 ** k for k in range(8)], 'date': date_now}
        Fractals._save("Sierpinski_Carpet", params)
        return None

    @staticmethod
    def Rivera_Fractal():
        """Creates the Rivera fractal (7 iterations, 2187 x 2187) and saves it."""
        date_now = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
        S = Fractals._substitution_3x3(((1, 0, 1), (1, 1, 1), (1, 0, 1)))
        params = {'name': "Rivera Fractal", 'fractal': S, 'cells_size': [3 ** k for k in range(8)], 'date': date_now}
        Fractals._save("Rivera_Fractal", params)
        return None

    @staticmethod
    def Brownian_Motion():
        """Samples a Brownian motion (Wiener process) path and saves its rasterized graph.

        The path has N = 2**13 samples spaced delta_t = 2**-13 apart. Values are
        discretized on a vertical grid of spacing delta_t; column n holds sample
        B[n], and consecutive samples are joined by straight lines. The number of
        rows is zero-padded to a multiple of 2**11 so that the box sizes (divisors
        of gcd(rows, columns)) include powers of two."""
        date_now = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
        a, p = 2, 13
        q = p - 2
        delta_t = a ** (-p)
        N = a ** p

        # B[0] = 0, then independent Gaussian increments of variance delta_t.
        B = np.zeros(N)
        B[1:] = np.cumsum(np.random.normal(loc=0, scale=np.sqrt(delta_t), size=N - 1))

        B_min, B_max = np.min(B), np.max(B)
        N_B = int((B_max - B_min) / delta_t) + 1
        S = np.zeros((N_B, N))

        # Row index of every sample on the delta_t grid (truncated, always >= 0).
        rows = ((B - B_min) / delta_t).astype(int)
        for n in range(N - 1):
            # Column n + 1 must hold sample B[n + 1]; using B[n] would shift the
            # whole path one column and lose the last sample.
            S = Tools.Line_Matrix(S, [int(rows[n]), n], [int(rows[n + 1]), n + 1])

        # Zero-pad the rows up to a multiple of a**q.
        F = np.zeros(((N_B // (a ** q) + 1) * a ** q, N))
        F[0:N_B, :] = S

        CS = Tools.divisors(np.gcd(F.shape[0], F.shape[1]))
        params = {'name': "Brownian Motion", 'fractal': F, 'cells_size': CS, 'date': date_now}
        Fractals._save("Brownian_Motion", params)
        return None

    @staticmethod
    def Koch_Snowflake():
        """Builds the Koch snowflake curve (10 iterations), rasterizes it on a
        10000 x 10000 array and saves it."""
        date_now = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")

        # Closed equilateral triangle as a 2 x M array of (x, y) points; the last
        # point repeats the first.
        X = np.array([[0, np.sqrt(3) / 2, -np.sqrt(3) / 2, 0], [-1, 1 / 2, 1 / 2, -1]])

        # Rotation by +90 degrees.
        J = np.array([[0, -1], [1, 0]])
        n_iter_KS = 10
        for j in range(n_iter_KS):
            print(" > Iteration: " + str(j+1) + " / " + str(n_iter_KS) + 20 * " ", end="\r")
            # Each segment A -> B becomes A -> A1 -> C -> B1 (-> B): A1 and B1 cut it
            # in thirds and C is the apex of an equilateral triangle on the middle third.
            A, B = X[:, :-1], X[:, 1:]
            A1, B1 = (2 / 3) * A + (1 / 3) * B, (1 / 3) * A + (2 / 3) * B
            C = (A + B) / 2 - (np.sqrt(3) / 6) * J @ (B - A)
            X = np.zeros((2, 4 * A.shape[1]))
            X[:, ::4] = A
            X[:, 1::4] = A1
            X[:, 2::4] = C
            X[:, 3::4] = B1
            X = np.concatenate((X, B[:, -1].reshape(2, 1)), axis=1)

        # Convert points to an array: [-1, 1] is mapped to [0.1 N, 0.9 N];
        # y gives the row index and x the column index.
        N = 10000  # Size of the array
        F = np.zeros((N, N))
        pix_i = (0.5 * N + 0.4 * N * X[1]).astype(int)
        pix_j = (0.5 * N + 0.4 * N * X[0]).astype(int)

        # Every one of the M - 1 segments is drawn, including the closing one.
        n_seg = X.shape[1] - 1
        for k in range(n_seg):
            # Throttled progress report: printing on each of ~3M segments is very slow.
            if k % 1000 == 0 or k == n_seg - 1:
                print(" > Matrix completion: " + str(np.round(100 * (k+1) / n_seg, decimals=2)) + " %" + 20 * " ", end="\r")
            F = Tools.Line_Matrix(F, [int(pix_i[k]), int(pix_j[k])], [int(pix_i[k + 1]), int(pix_j[k + 1])])

        CS = [1, 2, 4, 5, 10, 20, 25, 40, 50, 100, 125, 200, 250, 500, 1000, 1250, 2500, 5000, 10000]
        params = {'name': "Koch Snowflake", 'fractal': F, 'cells_size': CS, 'date': date_now}
        Fractals._save("Koch_Snowflake", params)
        return None

class Dimension:
    """Class for fractal dimension computation"""

    # Directory of the fractal dictionaries written by Fractals (np.save appends ".npy").
    _LOAD_DIR = r"C:\Documents\Implementaitons\TempProjects\Fractals_Dictionnaries"
    # Directory receiving the box-counting figures when save_fig is True.
    _FIG_DIR = r"C:\Documents\Implementaitons\TempProjects\Figures_Fractals\Fractal_Dimension_Box_Counting"
    # Number of largest cell sizes for which an intermediate figure is drawn.
    _N_PLOTTED_STEPS = 6

    @staticmethod
    def Box_Counting(name_fractal, save_fig=False):
        """Computes Fractal Dimension by using Box counting method.

        For every cell size cs in params['cells_size'] the fractal is tiled into
        cs x cs boxes and the number N of boxes whose maximum value is taken is
        counted (for 0/1 arrays: boxes containing a 1). The dimension estimate is
        minus the least-squares slope of log N against log(cs / max(cells_size)).
        Intermediate figures are drawn for the largest cell sizes, then a summary.

        Inputs:
        - name_fractal: str - Name of the desired fractal.
            > "SC": Sierpinski Carpet.
            > "RF": Rivera Fractal.
            > "BM": Brownian Motion.
            > "KS": Koch Snowflake.
        - save_fig: Boolean - Saves figures (PDF and PNG) or not. Default: False.

        Raises ValueError if the stored fractal is not a 2D array.
        """
        name_long = Tools.name(name_fractal)
        # allow_pickle is required because a dict was saved; only load trusted files.
        params = np.load(Dimension._LOAD_DIR + "\\" + name_long + ".npy", allow_pickle=True).item()

        F = params['fractal']
        if len(np.shape(F)) != 2:
            raise ValueError("Box counting is only implemented for 2D fractals; got shape " + str(np.shape(F)) + ".")

        cells_size = params['cells_size']
        n_sizes = len(cells_size)
        cs_max = max(cells_size)
        # Relative cell sizes eps = cs / max(cs) and the matching box counts N(eps).
        SIZE, COUNT = [], []
        fig_base = Dimension._FIG_DIR + "\\Dimension_Box_Counting_" + name_long + "_"

        print(150*"_")
        print(" ")
        print("  - Box counting process...")
        print(150*"_")
        print(" ")

        # Largest cells first.
        for k in reversed(range(n_sizes)):
            print("    > Step " + str(n_sizes - k) + " / " + str(n_sizes), end="\r")
            cs = cells_size[k]
            # Non-overlapping tiles: shape (rows // cs, cols // cs, cs, cs).
            G = skimage.util.view_as_blocks(F, (cs, cs))
            SIZE.append(cs / cs_max)
            COUNT.append(np.sum(np.max(G, axis=(2, 3))))

            if k >= n_sizes - Dimension._N_PLOTTED_STEPS:
                Dimension._plot_step(F, G, cs, np.array(SIZE), np.array(COUNT), params['name'])
                if save_fig:
                    for ext in (".pdf", ".png"):
                        plt.savefig(fig_base + str(k) + "_" + params['date'] + ext, dpi=1500)
                plt.show()

        SIZE, COUNT = np.array(SIZE, dtype=float), np.array(COUNT, dtype=float)
        LOG_SIZE, LOG_COUNT = np.log(SIZE), np.log(COUNT)
        # Closed-form least-squares slope of LOG_COUNT against LOG_SIZE; since
        # N(eps) ~ eps^(-D), the dimension D is minus that slope.
        M_S, M_C = np.mean(LOG_SIZE), np.mean(LOG_COUNT)
        M_SC, M_S2 = np.mean(LOG_SIZE * LOG_COUNT), np.mean(LOG_SIZE ** 2)
        DIM_BOX = -(M_SC - M_S * M_C) / (M_S2 - M_S ** 2)
        # Truncated (not rounded) to 4 decimals for display.
        DIM_SHOWN = int(DIM_BOX * 10000) / 10000

        print(150 * "_")
        print(" ")
        print("   > Dimension [estimation - Box counting]: ", DIM_SHOWN)
        print(150 * "_")

        plt.figure(0)
        Dimension._loglog_counts(SIZE, COUNT)
        plt.title("Box-counting dimension estimation for " + params['name'] + " - " + str(DIM_SHOWN))
        if save_fig:
            for ext in (".pdf", ".png"):
                plt.savefig(fig_base + params['date'] + ext, dpi=1500)
        plt.show()

        return None

    @staticmethod
    def _loglog_counts(SIZE, COUNT):
        """Scatters N(eps) against 1/eps on log-log scales in the current axes."""
        plt.xscale('log')
        plt.yscale('log')
        plt.scatter(1 / SIZE, COUNT, marker="s", color="green", label="Count")
        plt.grid()
        plt.xlabel(r"$\epsilon^{-1}$")
        plt.ylabel(r"$N(\epsilon)$")
        plt.legend()

    @staticmethod
    def _plot_step(F, G, cs, SIZE, COUNT, title_name):
        """Opens a figure with the counts so far (left) and F overlaid with the
        cs-grid and the outlines of the boxes containing a 1 (right)."""
        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        Dimension._loglog_counts(SIZE, COUNT)
        plt.title("Box-counting dimension estimation for " + title_name)

        plt.subplot(1, 2, 2)
        # Grid lines every cs pixels spanning the larger side of F; one column per line.
        L = max(F.shape[0], F.shape[1])
        ticks = np.arange(0, L + cs, cs)
        along = np.vstack([ticks, ticks])
        across = np.vstack([np.zeros(len(ticks)), np.full(len(ticks), float(L))])
        plt.plot(along, across, color="red", linewidth=1)  # vertical lines
        plt.plot(across, along, color="red", linewidth=1)  # horizontal lines
        # Closing lines along the last pixel row and column.
        plt.plot([0, F.shape[1]-1], [F.shape[0]-1, F.shape[0]-1], color="red", linewidth=1)
        plt.plot([F.shape[1]-1, F.shape[1]-1], [0, F.shape[0]-1], color="red", linewidth=1)

        # Closed 5-point outline of every box that contains at least one 1.
        iy, ix = np.nonzero(np.any(G == 1, axis=(2, 3)))
        x0, x1 = cs * ix, cs * (ix + 1) - 1
        y0, y1 = cs * iy, cs * (iy + 1) - 1
        Kx = np.stack([x0, x1, x1, x0, x0])
        Ky = np.stack([y0, y0, y1, y1, y0])
        plt.plot(Kx, Ky, color="green", linewidth=2)
        plt.imshow(F, cmap="gray")
        plt.axis("off")

class Print:
    """Class for printing saved figures"""

    # Directory of the fractal dictionaries written by Fractals (np.save appends ".npy").
    _LOAD_DIR = r"C:\Documents\Implementaitons\TempProjects\Fractals_Dictionnaries"
    # Directory receiving the fractal images when save_fig is True.
    _FIG_DIR = r"C:\Documents\Implementaitons\TempProjects\Figures_Fractals"

    @staticmethod
    def fractal(name_fractal, save_fig=False):
        """Print fractal.

        Loads the saved fractal dictionary and displays its array in gray levels.

        Inputs:
        - name_fractal: str - Name of the desired fractal.
            > "SC": Sierpinski Carpet.
            > "RF": Rivera Fractal.
            > "BM": Brownian Motion.
            > "KS": Koch Snowflake.
        - save_fig: Boolean - Saves the figure (PDF and PNG) or not. Default: False.
        """
        name_long = Tools.name(name_fractal)
        # Single separator so the path matches the one used when saving the file
        # (a doubled backslash only works on Windows). allow_pickle is required
        # because a dict was saved; only load trusted files.
        params = np.load(Print._LOAD_DIR + "\\" + name_long + ".npy", allow_pickle=True).item()
        plt.figure(0)
        plt.imshow(params['fractal'], cmap="gray")
        plt.axis("off")
        if save_fig:
            stem = Print._FIG_DIR + "\\" + name_long + "_" + params['date']
            for ext in (".pdf", ".png"):
                plt.savefig(stem + ext, dpi=1500)
        plt.show()
        return None