import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from matplotlib.animation import FuncAnimation
import sys


# Particle-based model of several species in competition (chicken, fox and snake).

class Computation:
    """Class for computation of trajectories of particles.

    Particles carry one of three labels (0, 1, 2) and perform a random walk
    on the unit square. When two particles are close enough, the label of
    the first one is replaced by the winner of the encounter.
    """

    # _WINNER[a, b] is the label that survives when labels a and b meet:
    # 1 beats 0, 2 beats 1 and 0 beats 2; equal labels are left unchanged.
    _WINNER = np.array([[0, 1, 0], [1, 1, 2], [0, 2, 2]])
    _N_SPECIES = 3

    def W(self, a, b):
        """Gives the winner between two classes indexed by a number.
        Inputs:
        - a: Int - Label of the first competitor
        - b: Int - Label of the second competitor
        Returns:
        - Label (0, 1 or 2) of the winner, read from the winner table."""
        return Computation._WINNER[a, b]

    def _count_labels(self, labels):
        """Returns how many entries of `labels` carry each of the three labels."""
        return [np.count_nonzero(labels == idx) for idx in range(Computation._N_SPECIES)]

    def Trajectories(self, J = 100, N = 100, delta = 0.05, v=0.01):
        """Computes trajectories of particles and writes them to disk.

        The positions and labels are saved together in "PPP.npy" (stacked as
        one array, so labels are stored as floats) and the number of
        particles per label at each iteration is saved in "PPP_Frac.npy".
        Both files are written in the current working directory. Nothing is
        returned.

        Inputs:
        - J: Int - Number of particles. Default: 100
        - N: Int - Number of time iterations. Default: 100
        - delta: Float - Distance below which two particles interact. Default: 0.05
        - v: Float - Speed of particle movement. Default: 0.01"""

        winner = Computation._WINNER

        X, Y = np.zeros((J, N)), np.zeros((J, N))
        L = np.zeros((J, N), dtype='int64')
        F = np.zeros((Computation._N_SPECIES, N))

        # Random draws are made in the order X, Y, labels so that a seeded
        # run gives the same initial state as before.
        x0 = np.random.uniform(low=0, high=1, size=(J,))
        y0 = np.random.uniform(low=0, high=1, size=(J,))
        L[:, 0] = np.random.randint(low=0, high=3, size=(J,))

        # Each label starts in its own square (half-width 0.15) centred on a
        # circle of radius 0.25 around the middle of the unit square.
        angle = 2 * L[:, 0] * np.pi / 3
        X[:, 0] = 0.5 + 0.15 * (2 * x0 - 1) + 0.25 * np.cos(angle)
        Y[:, 0] = 0.5 + 0.15 * (2 * y0 - 1) + 0.25 * np.sin(angle)
        F[:, 0] = self._count_labels(L[:, 0])

        print("Iterations:")
        for n in range(N - 1):
            sys.stdout.write("\r%d " % (n + 2) + "/" + str(N))
            sys.stdout.flush()

            # Random walk: a step of length v in a uniformly random
            # direction, wrapped back into the unit square.
            theta = np.random.uniform(low=0, high=2 * np.pi, size=(J,))
            X[:, n + 1] = np.mod(X[:, n] + v * np.cos(theta), 1)
            Y[:, n + 1] = np.mod(Y[:, n] + v * np.sin(theta), 1)

            # Work on a copy: updating a view of L[:, n] would overwrite the
            # stored labels of step n with those of step n+1.
            xs, ys = X[:, n], Y[:, n]
            current = L[:, n].copy()
            for i in range(J):
                # Plain Euclidean distance: the periodic wrapping of the
                # positions is not taken into account here.
                dist = np.sqrt((xs[i] - xs) ** 2 + (ys[i] - ys) ** 2)
                # Sequential update: neighbours are visited by increasing
                # index and already-updated labels are used as they are.
                for j in np.nonzero(dist < delta)[0]:
                    current[i] = winner[current[i], current[j]]
            L[:, n + 1] = current
            F[:, n + 1] = self._count_labels(current)

        np.save("PPP.npy", (X, Y, L), allow_pickle = True)
        np.save("PPP_Frac.npy", F)

class Plot(Computation):
    """Class for plotting computed trajectories."""

    def color(self, idx):
        """Gives color of a labelled point.
        Inputs:
        - idx: Int - Label of point.
        Returns:
        - Name of the color used for that label."""
        C = ["green", "red", "orange"]
        return C[idx]

    def plot(self, save = False):
        """Plots on a video the evolution of particles.

        Reads "PPP.npy" (positions and labels) and "PPP_Frac.npy" (number of
        particles per label) from the current working directory. The left
        panel shows the particles at the current iteration, the right panel
        the proportion of each label up to and including that iteration.

        Inputs:
        - save: Boolean. If true, the animation is written to "PPP.gif"
          instead of being displayed. Default: False"""

        # The first axis of the stacked array holds X, Y and the labels
        # (the labels come back as floats, hence the int() conversion below).
        X, Y, L = np.load("PPP.npy")
        F = np.load("PPP_Frac.npy")
        J, N = X.shape

        fig = plt.figure(figsize=(10, 5))

        def animation_func(n):
            # Redraw the whole figure for frame n.
            fig.clf()

            ax_particles = fig.add_subplot(1, 2, 1)
            point_colors = [self.color(int(label)) for label in L[:, n]]
            ax_particles.set_title("t = "+str(n))
            ax_particles.set_xlim(0, 1)
            ax_particles.set_ylim(0, 1)
            ax_particles.scatter(X[:, n], Y[:, n], c=point_colors, alpha=1)

            ax_fractions = fig.add_subplot(1, 2, 2)
            # Include iteration n itself so that the curves match the
            # particles shown on the left (and are not empty at frame 0).
            steps = np.arange(n + 1)
            history = F[:, :n + 1]
            ax_fractions.set_title("0 -> 2 -> 1 -> 0   ->: beats")
            for idx in range(3):
                ax_fractions.plot(steps, history[idx, :]/J, color = self.color(idx), label=str(idx))
            ax_fractions.legend(loc="upper right")
            ax_fractions.grid()
            ax_fractions.set_xlabel("Iterations")
            ax_fractions.set_ylabel("Proportions")

        # Keep a reference to the animation so it is not garbage collected
        # before it is shown or saved.
        animation = FuncAnimation(fig, animation_func, interval=100, blit=False, repeat=True, frames=N)
        if save:
            animation.save("PPP.gif", writer="pillow")
        else:
            fig.tight_layout()
            plt.show()

