#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Nov  4 21:00:47 2021

@author: maximebouchereau
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from scipy.linalg import block_diag
import sys

# Résolution de l'équation de la chaleur en deux dimensions d'espace: érosion naturelle avec terme source

# Paramètres physiques [à ajuster]

Lx = 10000 # Longueur du domaine en x
Ly = 10000 # Longueur du domaine en y

kencaissant = 1e-7 # Conductivité de l'encaissant 
rhoencaissant = 1 # Masse volumique de l'encaissant
cencaissant = 1 # Capacité thermique massique de l'encaissant

Lxdyke = 1000 # Longueur du dyke en x
Lydyke = 5000 # Longueur du dyke en y
kdyke = 1e-10 # Conductivité du dyke
rhodyke = 1 # Masse volumique du dyke
cdyke = 1 # Capacité thermique massique du dyke

U = 1E-11 # Vitesse d'uplift

tau = 25*3.156E+13 # Durée de la simulation (en s - égal à 1 Gannées)



# Paramètres de la simulation numérique [à ajuster]

J = 50 # Nombre de subdivisions en espace selon chaque direction
N = 200 # Nombre de subdivisions en temps



# Paramètres calculés

xx = np.linspace(0,Lx,J+1) # Discrétisation de l'intervalle des x
yy = np.linspace(0,Ly,J+1) # Discrétisation de l'intervalle des y

Ddyke = kdyke/(rhodyke*cdyke) # Coefficient de diffusion du dyke
Dencaissant = kencaissant/(rhoencaissant*cencaissant) # Coefficient de diffusion de l'encaissant

MatDiff = np.zeros((J+1,J+1)) # Donne le coefficient de diffusion thermique en fonction de l'endroit sur lequel on se trouve

for i in range(J+1):
    for j in range(J+1):
        if abs(xx[i]-Lx/2) <= Lxdyke/2 and yy[j] >= Ly-Lydyke:# Test de vérification si on se trouve sur le dyke
            MatDiff[i,j] = Ddyke
        else: # C'est que l'on se trouve sure l'encaissant
            MatDiff[i,j] = Dencaissant

MatDiffVect = MatDiff.T.reshape((J+1)*(J+1),) # Transforme MatDiff en vecteur (changement de dimension)
MatDiffNum = np.diag(MatDiffVect) # Transforme le vecteur précédent en matrice diagonale utilisable dans le schéma numérique


hx = Lx/J # Pas de subdivision en espace selon x
hy = Ly/J # Pas de subdivision en espace selon y
ht = tau/N # Pas de subdivision en temps



# Construction des matrices impliquées dans le schéma numérique

# A est la matrice qui va imiter le laplacien en 1D selon x (avec conditions de Neumann)
A = np.diag(np.ones(J),-1) + np.diag(np.ones(J),1) - np.diag(2*np.ones(J+1),0)
A[0,0] = -1
A[J,J] = -1

# B est la matrice qui va imiter le laplacien en 2D (avec conditions de Neumann)
F = (1/hx**2)*A-(2/hy**2)*np.eye(J+1,J+1)
Fbis = (1/hx**2)*A-(1/hy**2)*np.eye(J+1,J+1)

B = Fbis

print("   ")
print("   ")
print("Construction de la matrice diagonale par blocs...")
for j in range(J-1):
    sys.stdout.write("\r%d   "%int(100*(j+1)/(J-1))+"%")
    sys.stdout.flush()
    B = block_diag(B,F) # Construction de la "diagonale" de B par blocs

B = block_diag(B,Fbis) # On termine par le dernier bloc qui est différent

B = B + np.diag((1/hy**2)*np.ones(J*(J+1)),J+1) + np.diag((1/hy**2)*np.ones(J*(J+1)),-J-1)
#B = -(1/hx**2)*np.eye(B.shape[0],B.shape[1])

# M est la matrice qui sera utilisée à chaque itération dans le schéma numérique

I = np.eye((J+1)**2,(J+1)**2) # Matrice identité
M = np.linalg.inv(I-ht*MatDiffNum@B) # Inversion à cause du schéma implicite
#M = I + ht*MatDiffNum@B # Matrice d'itération pour schéma explicite


# Calcul de la solution approchée de l'équation de la chaleur

H = np.zeros(((J+1)**2,N+1)) # Collection de vecteurs contenant les solutions aux temps t_n

# Construction de la condition initiale et du terme source

Hinit = np.zeros((J+1,J+1)) # Altitude initiale du domaine nulle

UUsource = np.zeros((J+1,J+1)) # Construction du terme source

print("   ")
print("   ")
print("Construction de la matrice associée au terme source...")
for i in range(J+1):
    for j in range(J+1):
        sys.stdout.write("\r%d   "%int(100*(i*(J+1)+j+1)/((J+1)**2))+"%")
        sys.stdout.flush()
        if yy[j] >= Ly/2: # Test de vérification si on se trouve au nord de la faille
            UUsource[i,j] = U
        else: # C'est qu'on est au sud de la faille
            UUsource[i,j] = 0


H[:,0] = Hinit.T.reshape((J+1)**2) # Vecteur associé à la condition initiale redimensionné en vecteur colonne
Usource = UUsource.T.reshape((J+1)**2) # Vecteur source redimensionné en vecteur colonne

# Calcul de la solution à chaque temps

print("   ")
print("   ")
print("Calcul de la solution...")
for n in range(N):
    sys.stdout.write("\r%d   "%int(100*(n+1)/N)+"%")
    sys.stdout.flush()
    H[:,n+1] = M@(H[:,n]+ht*Usource)
    #H[:,n+1] = M@H[:,n]+ht*Usource



# Calcul de la date à laquelle l'équilibre est atteint

neq = 0 # Compte le nombre de points où la température a dépassé les 600°C dans le domaine discrétisé
toleq = 1e-12 # Tolérance (seuil) en dessous duquel on considère que l'équilibre a lieu

while (np.max(abs(H[:,neq+1]-np.min(H[:,neq+1])-H[:,neq]+np.min(H[:,neq])))) > toleq*ht and neq<=N-2:
    neq = neq+1

print("   ")
print("   ")
print("L'équilibre a lieu au bout de",0.1*round(neq*ht/(85400*365.25*100000)),"millions d'années") # On retire Ldyke*Hdyke, car ce domaine n'est pas contitué d'encaissant



# Détermination de l'altitude maximale et minimale au moment de l'équilibre

print("Altitude minimale au moment de l'équilibre:",format(np.min(H[:,neq]),'.1f'),"m")
print("Altitude maximale au moment de l'équilibre:",format(np.max(H[:,neq]),'.1f'),"m")



# Tracé de la solution au bout d'un million d'années

plt.figure()
plt.title("Altitude à la fin de la simulation [m]")
plt.xlabel("x")
plt.ylabel("y")
plt.imshow(np.flipud(H[:,N].reshape(J+1,J+1)),cmap="jet",aspect="equal",extent=(0,Lx,0,Ly),vmin=np.min(H[:,N]),vmax=np.max(H[:,N]))
plt.colorbar(format='%.f')
plt.show()

# Tracé de la solution au bout d'un million d'années en 3D

xxx,yyy = np.meshgrid(xx,yy,indexing='ij')
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
surf = ax.plot_surface(xxx,yyy,np.flipud(H[:,N].reshape(J+1,J+1)),cmap='jet')
plt.xlabel("x")
plt.ylabel("y")
plt.title("Altitude à la fin de la simulation [m]")
plt.show()



# Animation du tracé de la solution au cours du temps (bonus)

deltat = 5000/N # Durée entre deux frames (en ms)

fig=plt.figure()
im = plt.imshow(np.flipud(H[:,0].reshape(J+1,J+1)),cmap="jet",aspect="equal",extent=(0,Lx,0,Ly),vmin=np.min(H),vmax=np.max(H))
plt.title("Evolution de l'altitude [m]")
plt.xlabel("x")
plt.ylabel("y")
plt.colorbar(format='%.f')

def animate(n):
    im.set_array(np.flipud(H[:,n].reshape(J+1,J+1)))
    return [im]

anim = animation.FuncAnimation(fig, animate, frames=N, blit=True, interval=deltat, repeat=True)
#%matplotlib qt
plt.show()



# Animation du tracé de la solution au cours du temps en 3D(bonus++)

plt.rcParams["figure.figsize"] = [7.50, 3.50]
plt.rcParams["figure.autolayout"] = True

deltat = 5000/N # Durée entre deux frames (en ms)

mx, my = np.meshgrid(xx, yy)
HH = np.zeros((J + 1, J + 1, N))


for n in range(N):
   HH[:, :, n] = np.flipud(H[:,n].reshape(J+1,J+1))

def animate(n, HH, plot):
   plot[0].remove()
   plot[0] = ax.plot_surface(mx, my, HH[:, :, n], cmap="jet",vmin=np.min(H),vmax=np.max(H))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

plot = [ax.plot_surface(mx, my, HH[:, :, 0], color='0.75', rstride=1, cstride=1)]

ax.set_zlim(np.min(H), np.max(H))
ani = animation.FuncAnimation(fig, animate, N, fargs=(HH, plot), interval=deltat,repeat=True)

plt.xlabel("x")
plt.ylabel("y")
plt.title("Evolution de l'altitude [m]")
ax.view_init(elev=30,azim=45)

plt.show()

























