import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from matplotlib import animation
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.colors import LightSource , colorConverter
import itertools
import random

import warnings
import matplotlib as mpl
from matplotlib import cm

# Suppress every warning process-wide for the rest of this script's run
# (the "ignore" filter applies to all modules, not only this one).
warnings.filterwarnings("ignore")

# Print a three-line banner: a 60-dash rule, the title, and another rule.
print("-" * 60, " Menger sponge approximation", "-" * 60, sep="\n")

class MS:
    """Build and plot successive approximations of the Menger sponge on [-1, 1]^3."""

    # RGB face colours for each style, indexed by the axis a face is
    # perpendicular to: [x-facing, y-facing, z-facing].
    _COLOR_STYLES = {
        "green": [[0.8, 1, 0], [0.0, 0.8, 0], [0, 0.3, 0]],
        "red": [[1, 1, 0], [1, 0.5, 0], [1, 0, 0]],
        "blue": [[0.5, 1, 1], [0.25, 0.5, 0.9], [0, 0, 0.8]],
    }

    @staticmethod
    def _cube_faces():
        """Return the six square faces of the cube [-1, 1]^3 as a (6, 4, 3) float array."""
        return np.array(
            [
                [[-1, -1, -1], [-1, -1, 1], [-1, 1, 1], [-1, 1, -1]],  # x = -1
                [[1, -1, -1], [1, -1, 1], [1, 1, 1], [1, 1, -1]],  # x = +1
                [[-1, -1, -1], [1, -1, -1], [1, -1, 1], [-1, -1, 1]],  # y = -1
                [[-1, 1, -1], [1, 1, -1], [1, 1, 1], [-1, 1, 1]],  # y = +1
                [[-1, -1, -1], [-1, 1, -1], [1, 1, -1], [1, -1, -1]],  # z = -1
                [[-1, -1, 1], [-1, 1, 1], [1, 1, 1], [1, -1, 1]],  # z = +1
            ],
            dtype=float,
        )

    @staticmethod
    def _sponge_faces(n):
        """Return the faces of the n-th Menger sponge approximation.

        Each step shrinks the current solid by a factor 3 and places 20 copies
        of it on a 3x3x3 grid. A non-positive ``n`` yields the plain cube.

        Returns:
            A float array of shape (6 * 20**max(n, 0), 4, 3); each entry is one
            square face given by its four vertices.
        """
        # Keep the 12 edge and 8 corner cells of the 3x3x3 grid (at least two
        # non-zero coordinates); drop the centre and the six face centres.
        offsets = (2 / 3) * np.array(
            [v for v in itertools.product((-1, 0, 1), repeat=3) if np.count_nonzero(v) >= 2],
            dtype=float,
        )
        faces = MS._cube_faces()
        for _ in range(n):
            # (1, F, 4, 3) + (20, 1, 1, 3) -> (20, F, 4, 3). Flattening keeps the
            # offset as the outer index and the face as the inner one.
            faces = (faces[np.newaxis] / 3 + offsets[:, np.newaxis, np.newaxis, :]).reshape(-1, 4, 3)
        return faces

    @staticmethod
    def _face_axes(faces):
        """Return, for each face of an (N, 4, 3) array, the axis (0, 1 or 2) it is perpendicular to."""
        # A face is perpendicular to the axis along which all of its vertices
        # share the same coordinate (zero extent, up to rounding).
        flat = (faces.max(axis=1) - faces.min(axis=1)) < 1e-10  # shape (N, 3)
        return np.where(flat[:, 0], 0, np.where(flat[:, 1], 1, 2))

    @staticmethod
    def Iteration(n, style="green", save=False):
        """Approximation of the Menger sponge after a determined number of iterations.

        Inputs:
        - n: Int - Number of iterations selected (values <= 0 draw the plain cube)
        - style: Str - Color style of the plot: "green", "red" or "blue" - Default: Green
        - save: Boolean - Saves the produced figure (as PDF and PNG) or not - Default: False

        Raises:
        - ValueError: if ``style`` is not one of the supported styles.
        """
        # Validate before doing any work so a typo fails fast and clearly.
        if style not in MS._COLOR_STYLES:
            raise ValueError(
                f"Unknown style {style!r}; expected one of {sorted(MS._COLOR_STYLES)}"
            )
        palette = np.array(MS._COLOR_STYLES[style], dtype=float)

        faces = MS._sponge_faces(n)
        # Shade each face by its orientation so the 3D structure stays readable.
        colors = palette[MS._face_axes(faces)]

        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(111, projection='3d')
        ax.add_collection3d(
            Poly3DCollection(list(faces), facecolors=colors, linewidths=0.00, edgecolors='k', alpha=1.0)
        )

        ax.set_title(f"n = {n}")
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)
        ax.axis("off")
        if save:
            fig.savefig("Menger_Sponge_n=" + str(n) + ".pdf")
            fig.savefig("Menger_Sponge_n=" + str(n) + ".png")
        plt.show()
