#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 15 23:52:44 2021

@author: maximebouchereau
"""

import numpy as np
import math as mt
import matplotlib.pyplot as plt
from scipy.special import erfinv

# Modélisation du reffroidissement de la lithosphère océanique en surface



# Paramètres physiques [à ajuster]

H = 1.2E5 # Hauteur étudiée
Tsurf = 0 # Température de surface
Tmanteau = 1350 # Température du manteau
K = 1E-6 # Coefficient de diffusion thermique

annees = 50000000 # Durée de simulation, en années



# Paramètres de simulation numérique [à ajuster]

J = 600 # Nombre de subdiviions en espace
N = 500 # Nombre de subdiviions en temps



# Paramètres calculés

tau  = 86400*365.25*annees # Durée de la simulation, en s

hz = H/J # Pas de discrétisation en espace
ht = tau/N # Pas de discrétisation en temps

domaine = np.linspace(-H,0,J+1)



# Construction des matrices impliquées dans le schéma numérique

# A est la matrice qui va imiter la laplacien 1D avec conditions de Dirichlet
A = np.diag(np.ones(J-2),-1) + np.diag(np.ones(J-2),1) - np.diag(2*np.ones(J-1),0)

# M est la matrice qui sera utilisée à chaque itération dans le schéma numérique

I = np.eye(J-1,J-1) # Matrice identité
M = np.linalg.inv(I-K*(ht/hz**2)*A) # Inversion à cause du schéma implicite



# Calcul de la solution approchée de l'équation de la chaleur

U = np.zeros((J+1,N+1)) # Collection de vecteurs contenant les solutions aux temps t_n (attention aux doubles parenthèses !!!)

# Construction de la condition initiale
U[:,0] = Tmanteau*np.ones(J+1)

#  Construction d'un vecteur "source" tenant compte des conditions de Dirichlet
F = np.zeros(J-1,)
F[0] = (K/hz**2)*Tmanteau

# Calcul de la solution à chaque temps
for n in range(N):
    U[1:J,n+1] = M@(U[1:J,n]+ht*F)
    U[0,n+1] = Tmanteau # Condition de Dirichlet non nulle prise en compte
    


# Calcul de la durée nécéssaire pour que la lithosphère atteigne une épaisseur de 80 km

n80 = 0
ep = 0

while ep < 80000 and n80<N:
    ep = list(np.heaviside(U[:,n80]-1250,0)).count(0)/(J+1)*H # Compte la proportion de points situés au-dessous de 1250°C x H
    n80 = n80 + 1
    
print("Il faut",int(n80*ht/(86400*365.25)),"années pour que la lithosphère atteigne 80 km d'épaisseur")


# Calcul via la solution analytique

p = 80000 # Epaisseur de la lithosphère étudiée (80 km, conversion en m)
Tlith = 1250 # Température à en dessous de laquelle la lithosphère est définie (en °C)
t = (1/(4*K))*(p/erfinv(Tlith/Tmanteau))**2

print("   ")

print("Avec la solution analytique, on trouve:",int(t/(86400*365.25)),"années pour que la lithosphère atteigne 80 km d'épaisseur")




# Solution analytique

def T(z,t):
    "Solution analytique"
    temp = Tmanteau*mt.erf(-z/np.sqrt(4*K*t))
    return temp


Uexact_050 = np.array([T(z,tau/2) for z in domaine])
Uexact_100 = np.array([T(z,tau) for z in domaine])



# Tracé de la solution à différents temps en 1D

plt.figure()
plt.title("Profil de la température dans la lithosphère en "+str(annees)+" années")
plt.xlabel("Profondeur [m]")
plt.ylabel("Température [°C]")
plt.plot(domaine,U[:,0],label="Température initiale",color='red')
plt.plot(domaine,U[:,N//2],label="Température à la moitié de la simulation",color="orange")
plt.plot(domaine,Uexact_050,label="Température à la moitié de la simulation (solution exacte)",color="orange",linestyle="--")
plt.plot(domaine,U[:,N],label="Température à la fin de la simulation",color="yellow")
plt.plot(domaine,Uexact_100,label="Température à la fin de la simulation (solution exacte)",color="yellow",linestyle="--")
plt.legend()
plt.grid()
plt.show()



# Tracé de la solution au cours du temps (bonus)

plt.figure()
plt.title("Evolution de la température au cours du temps [°C]")
plt.xlabel("temps [années]")
plt.ylabel("Profondeur [km]")
plt.imshow(U,cmap="jet",aspect="auto",extent=(0,annees,-H/1000,0))
plt.colorbar()
plt.show()


