# File imports

import warnings
# NOTE(review): this silences *every* warning for the whole process, including
# NumPy/Matplotlib deprecation notices that could point at real problems;
# consider narrowing it to specific categories or modules.
warnings.filterwarnings('ignore')

import numpy as np

import math
import scipy.optimize

import matplotlib.pyplot as plt
from matplotlib import animation

import time
import datetime
from datetime import datetime as dtime

# Program

print(100*"-")
print("""Modelling the N-body problem in physics""")
print(100*"-")

# Parameters

N_simul = 5     # number of bodies
T_simul = 1     # final time of the simulation
h_simul = 0.01  # integration step size
R = 1           # initial positions are drawn in the disk of radius R; also the animation's half-width
deltat = 0.01   # pause between two animation frames, in SECONDS (it is passed to plt.pause)

print(" ")
print("Parameters:")
print(" ")
print(" - Number of bodies:" , N_simul)
print(" - Time for simulation:" , T_simul)
print(" - Step size:" , h_simul)
print(" - Half-length of the simulation's domain for initial data:" , R)
# plt.pause interprets its argument in seconds, so report the unit accordingly.
print(" - Interval between two frames print in animation:" , deltat , "s")

# Simulation

def F(X , reg = 1e-2):
    """Gravitational force vector field (unit masses, unit gravitational constant).

    Body ``j`` is pulled towards every other body ``i`` by
    ``(x_i - x_j) / (|x_i - x_j|**3 + reg)``, and the contributions are summed.

    Input:
    - X: Array of shape (d,N), where N is the number of bodies involved and d is the dimension of the problem - Space variable
    - reg: Float: Regularization parameter added to the cubed distances, for numerical stability. Default: 1e-2
    Output:
    - Array with the same shape and dtype as X: the force acting on each body."""
    dim, n_bodies = X.shape

    # pairwise[k, i, j] = X[k, i] - X[k, j]
    pairwise = X[:, :, np.newaxis] - X[:, np.newaxis, :]

    # Squared pairwise distances. The diagonal receives `dim` (one unit per
    # coordinate) so it is never zero; its numerator is zero anyway, so
    # self-interaction contributes nothing.
    sq_dist = np.sum(pairwise ** 2, axis=0) + dim * np.eye(n_bodies, n_bodies)

    # Sum over the source bodies i (axis 1), leaving one force per target body j.
    forces = np.zeros_like(X)
    forces[...] = np.sum(pairwise / (sq_dist ** 1.5 + reg), axis=1)
    return forces

def _initial_conditions(N):
    """Return (Q_0, P_0): N random positions in the disk of radius R and zero momenta, each of shape (2, N)."""
    r_0 = np.random.uniform(low=0, high=R, size=(1, N))
    theta_0 = np.random.uniform(low=0, high=2 * np.pi, size=(1, N))
    # NOTE(review): drawing r uniformly (rather than sqrt-distributed) clusters
    # bodies near the centre instead of spreading them uniformly over the disk;
    # kept as in the original model.
    Q_0 = np.concatenate((r_0 * np.cos(theta_0), r_0 * np.sin(theta_0)), axis=0)
    P_0 = np.zeros_like(Q_0)  # all bodies start at rest
    return Q_0, P_0


def _print_progress(done, total, start):
    """Overwrite the current console line with the completion percentage and an estimated end time."""
    elapsed = time.time() - start
    end = start + elapsed * total / done
    end_str = datetime.datetime.fromtimestamp(int(end)).strftime(' %Y-%m-%d %H:%M:%S')
    # Longer runs get more decimals (between 1 and 4).
    decimals = min(max(int(np.log10(total)) - 2, 1), 4)
    # Truncate rather than round so 100 % is shown only once everything is done.
    percent = np.floor(100 * done / total * 10 ** decimals) / 10 ** decimals
    # '\r' must come first so each update replaces the previous one; flush
    # because end="" never triggers a line-buffered flush.
    print("\r Loading :  {:.{}f} %  - Estimated time for ending : {}".format(percent, decimals, end_str),
          end="", flush=True)


def _integrate(Q_0, P_0, n_times, h, regul):
    """Integrate the equations of motion; return Q, P of shape (n_times, d, N)."""
    Q = np.zeros((n_times,) + Q_0.shape)
    P = np.zeros((n_times,) + P_0.shape)
    Q[0], P[0] = Q_0, P_0

    n_steps = n_times - 1
    start = time.time()
    for n in range(n_steps):
        # Semi-implicit Euler: positions first, then momenta evaluated at the
        # *updated* positions.
        Q[n + 1] = Q[n] + h * P[n]
        P[n + 1] = P[n] + h * F(Q[n + 1], regul)
        _print_progress(n + 1, n_steps, start)
    if n_steps > 0:
        print()  # terminate the progress line
    return Q, P


def _plot_trajectories(Q, N):
    """Show the full trajectory of every body on a black background."""
    # Draw on one explicit Axes: calling plt.axes() after plotting would add a
    # new opaque Axes on top and hide the trajectories.
    fig, ax = plt.subplots()
    ax.set_facecolor("black")
    for body in range(N):
        ax.plot(Q[:, 0, body], Q[:, 1, body])
    ax.set_title("Evolution of " + str(N) + " bodies")
    plt.show()


def _animate(Q, N, T, h, save_fig):
    """Replay the positions frame by frame, optionally saving each frame as a PNG file."""
    fig = plt.figure()
    for n in range(Q.shape[0]):
        ax = fig.add_subplot(111)
        ax.set_facecolor("black")
        ax.set_title("t=" + str(round((n * h), 2)))
        # Fixed window matching the initial-data disk; bodies leaving it are not shown.
        ax.set_xlim(-R, R)
        ax.set_ylim(-R, R)
        ax.set_aspect('equal', 'box')
        ax.grid()
        ax.scatter(Q[n, 0, :], Q[n, 1, :], color="white", s=20, marker="o")
        # plt.pause redraws the figure and waits `deltat` seconds.
        plt.pause(deltat)
        if save_fig:
            fig.savefig("N_body_problem_N=" + str(N) + "_T=" + str(T) + "_h=" + str(h) + "_n=" + str(n) + ".png")
        fig.clear()


def Run(N = N_simul , T = T_simul , h = h_simul , regul = 1e-2 , save_fig = False):
    """Runs a simulation of the N-body problem and displays it.

    The bodies start at rest at random positions inside the disk of radius R.
    They are integrated with a semi-implicit Euler scheme driven by the force
    field F, then shown as a static trajectory plot followed by a
    frame-by-frame animation.

    Inputs:
    - N: Int - Number of involved bodies. Default: N_simul
    - T: Float - Time for simulation. Default: T_simul
    - h: Float - Step size. Default: h_simul
    - regul: Float: Regularization parameter, for the numerical stability. Default: 1e-2
    - save_fig: Boolean - Saves every animation frame as a PNG file or not. Default: False"""
    Q_0, P_0 = _initial_conditions(N)
    TT = np.arange(0, T + h, h)
    Q, P = _integrate(Q_0, P_0, TT.size, h, regul)
    _plot_trajectories(Q, N)
    _animate(Q, N, T, h, save_fig)

