import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from matplotlib import animation
import itertools
import random

import warnings
import matplotlib as mpl
from matplotlib import cm

# Silence all warnings process-wide for the rest of the script.
warnings.filterwarnings("ignore")

# Startup banner: a 60-dash rule, the title, and the rule again.
_RULE = "-" * 60
print(_RULE, " Koch snowflake approximation", _RULE, sep="\n")

class Creation:
    """Generation of the successive Koch snowflake approximations."""

    @staticmethod
    def Iteration(n):
        """Creates the points which belong to the approximation of Koch snowflake at order n.

        Order 0 is the equilateral triangle with vertices (0, -1),
        (sqrt(3)/2, 1/2) and (-sqrt(3)/2, 1/2), stored as a closed polyline
        (first vertex repeated at the end). Each iteration replaces every
        segment [A, B] by the four segments A -> A1 -> C -> B1 -> B, where
        A1 and B1 split [A, B] into thirds and C is the apex of the
        equilateral bump built on [A1, B1].

        All orders 0..n are saved with np.savez to
        "Koch_Snowflake_n=<n>.npz" (keys arr_0, ..., arr_<n>).

        Inputs:
        - n: Int - Number of iterations

        Returns:
        - List of n+1 arrays of shape (2, 3*4**j + 1), j = 0..n: the closed
          polylines of every order (row 0 = x, row 1 = y)."""
        X = np.array([[0, np.sqrt(3)/2, -np.sqrt(3)/2, 0],
                      [-1, 1/2, 1/2, -1]])
        KS = [X]

        # Rotation by +90 degrees, used to get the normal of each segment.
        J = np.array([[0, -1], [1, 0]])

        for _ in range(n):
            # Process every segment at once: column k of A / B is the
            # start / end point of segment k.
            A, B = X[:, :-1], X[:, 1:]
            A1 = (2/3)*A + (1/3)*B
            B1 = (1/3)*A + (2/3)*B
            # Apex of the bump; with the counter-clockwise vertex order of the
            # initial triangle, the minus sign makes the bump point outward.
            C = (A + B)/2 - (np.sqrt(3)/6)*J @ (B - A)
            # Interleave per segment: (2, s, 4) -> (2, 4s) ordered
            # A_0, A1_0, C_0, B1_0, A_1, A1_1, ...
            Y = np.stack((A, A1, C, B1), axis=2).reshape(2, -1)
            # Close the polyline with the last point of the previous order.
            Y = np.concatenate((Y, X[:, -1:]), axis=1)

            KS.append(Y)
            X = Y

        np.savez("Koch_Snowflake_n=" + str(n) + ".npz", *KS)
        return KS

class Plot:
    """Display of the Koch snowflake approximations saved by Creation.Iteration."""

    @staticmethod
    def Plot_Iter(n, save=False):
        """Plot the last iteration of the approximation process of a Koch snowflake.

        Reads "Koch_Snowflake_n=<n>.npz" (as written by Creation.Iteration(n))
        and draws its highest-order polyline, each segment colored along the
        rainbow colormap.

        Inputs:
        - n: Int - Number of iterations required
        - save: Boolean - Saves the figure or not. Default: False

        Raises:
        - FileNotFoundError if the .npz file for this n does not exist."""
        from matplotlib.collections import LineCollection

        name_KS = "Koch_Snowflake_n=" + str(n) + ".npz"

        def load_last_array(filename):
            # np.savez names positional arrays arr_0, arr_1, ..., arr_10, ...;
            # order on the integer suffix, because a plain string sort puts
            # "arr_10" before "arr_2" and would pick the wrong order for n >= 10.
            # The context manager closes the underlying file handle.
            with np.load(filename) as npzfile:
                last_key = max(npzfile.files,
                               key=lambda key: int(key.rsplit("_", 1)[1]))
                return npzfile[last_key]

        try:
            X = load_last_array(name_KS)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"{name_KS} not found; run Creation.Iteration({n}) first"
            ) from err

        # Points (2, m) -> m-1 segments of shape (2, 2): [[x0, y0], [x1, y1]].
        points = X.T
        segments = np.stack((points[:-1], points[1:]), axis=1)
        colors = mpl.cm.rainbow(np.linspace(0, 1, X.shape[1]))

        plt.figure()
        ax = plt.axes(aspect="equal")
        plt.title("n=" + str(n))
        # A single LineCollection instead of one plt.plot call per segment,
        # which becomes very slow as the segment count grows like 4**n.
        ax.add_collection(LineCollection(segments, colors=colors[:-1]))
        plt.xlim(-1, 1)
        plt.ylim(-1, 1)
        if save:
            plt.savefig("Koch_Snowflake_n=" + str(n) + ".pdf")
        plt.show()