import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from matplotlib import animation
import itertools
import random

import warnings
import matplotlib as mpl
from matplotlib import cm

# Suppress every warning for the remainder of the process.
# NOTE(review): this also hides deprecation notices from numpy/matplotlib.
warnings.filterwarnings(action="ignore")

# Header banner: a 60-dash rule, the title, then another rule, one per line.
print(60 * "-", " Hilbert Curve approximation", 60 * "-", sep="\n")

class Creation:
    """Construction of successive Hilbert curve approximations."""

    @staticmethod
    def Iteration(n):
        """Creates the points which belong to the approximation of Hilbert Curve at order n.

        The order-1 curve visits the four cell centres of [-1, 1]^2 in a U
        shape. Each further order places four transformed, half-scale copies
        of the previous curve into the four quadrants, so order k has 4**k
        points lying on a regular 2**k x 2**k grid of cell centres.

        All orders 1..n are written to ``Hilbert_Curve_n=<n>.npz`` with
        ``np.savez`` (keys arr_0, arr_1, ..., in increasing order).

        Inputs:
        - n: Int - Number of iterations (order of the last approximation).
          Values below 1 produce only the order-1 curve.

        Returns:
        - List of arrays of shape (2, 4**k) for k = 1, ..., max(n, 1); row 0
          holds the x coordinates and row 1 the y coordinates, in curve order."""
        X = np.array([[-0.5, -0.5, 0.5, 0.5],
                      [-0.5, 0.5, 0.5, -0.5]])
        HC = [X]

        # Linear part of the map that sends the previous curve into each
        # quadrant: alpha swaps x and y, beta and gamma are the identity,
        # delta maps (x, y) -> (-y, -x). These make consecutive copies join up.
        P_alpha = np.array([[0, 1], [1, 0]])
        P_beta = np.array([[1, 0], [0, 1]])
        P_gamma = np.array([[1, 0], [0, 1]])
        P_delta = np.array([[0, -1], [-1, 0]])

        # Translations, applied before the 1/2 rescaling, that move each copy
        # to the bottom-left, top-left, top-right and bottom-right quadrant.
        u_alpha = np.array([[-1], [-1]])
        u_beta = np.array([[-1], [1]])
        u_gamma = np.array([[1], [1]])
        u_delta = np.array([[1], [-1]])

        for _ in range(n - 1):
            # Each u has shape (2, 1) and broadcasts across every column of X.
            X = 0.5 * np.concatenate(
                (P_alpha @ X + u_alpha,
                 P_beta @ X + u_beta,
                 P_gamma @ X + u_gamma,
                 P_delta @ X + u_delta),
                axis=1,
            )
            HC.append(X)

        np.savez("Hilbert_Curve_n=" + str(n) + ".npz", *HC)
        return HC

class Plot:
    """Plotting helpers for Hilbert curve approximations saved to .npz files."""

    @staticmethod
    def _load_last_array(filename):
        """Return the last array stored in ``filename`` by ``np.savez(filename, *lst)``.

        ``np.savez`` names positional arrays arr_0, arr_1, ...; the last one is
        selected by its numeric suffix because a plain string sort would place
        arr_10 before arr_2. Assumes every key follows that arr_<i> pattern.
        The file is closed before returning."""
        with np.load(filename) as npzfile:
            last_key = max(npzfile.files, key=lambda key: int(key.rsplit("_", 1)[1]))
            return npzfile[last_key]

    @staticmethod
    def Plot_Iter(n, save=False):
        """Plot the last iteration of the approximation process of an Hilbert Curve.

        Reads ``Hilbert_Curve_n=<n>.npz`` (the file written by
        ``Creation.Iteration(n)``), rescales the points from [-1, 1]^2 to the
        unit square and draws the curve with a rainbow colour gradient along
        its length.

        Inputs:
        - n: Int - Number of iterations required
        - save: Boolean - Saves the figure as ``Hilbert_Curve_n=<n>.pdf`` or not. Default: False

        Raises:
        - FileNotFoundError if the .npz file for this n does not exist yet."""
        from matplotlib.collections import LineCollection

        X = Plot._load_last_array("Hilbert_Curve_n=" + str(n) + ".npz")

        # Map coordinates from [-1, 1] to [0, 1].
        X = (X + 1) / 2

        # Build all N-1 segments at once: shape (N-1, 2 endpoints, 2 coords).
        points = X.T
        segments = np.stack((points[:-1], points[1:]), axis=1)

        # Colour ramp over the N points; segment j takes the colour of point j.
        colors = mpl.cm.rainbow(np.linspace(0, 1, X.shape[1]))[:-1]

        fig, ax = plt.subplots()
        ax.set_aspect("equal")
        ax.set_title("n=" + str(n))
        # One collection instead of one plt.plot call per segment (4**n of them).
        ax.add_collection(LineCollection(segments, colors=colors))
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        if save:
            fig.savefig("Hilbert_Curve_n=" + str(n) + ".pdf")
        plt.show()