import warnings
from tkinter.tix import TCL_TIMER_EVENTS

# Silence every warning for the rest of the process. This is installed before
# the heavy third-party imports below, presumably to keep the console output
# limited to the progress messages printed by this script.
# NOTE(review): a blanket 'ignore' also hides deprecation warnings triggered by
# this script's own code -- consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

import torch
import torch.optim as optim
import torch.nn as nn
import copy

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.cm as cm
from matplotlib.animation import FuncAnimation
from scipy.integrate import solve_ivp
from scipy.optimize import fixed_point

from itertools import product
import statistics

import sys
import time
import datetime
from datetime import datetime as dtime

### SNAKES AND LADDERS GAME MODELLING WITH MARKOV CHAINS

# Default configuration: side length of the board, number of players, and the
# number of snakes (the same value is reported as the number of ladders).
params = dict(Game_Length=10, Players=20, Snakes_Ladders=5)

# Title banner.
print(
    150 * "_",
    " ",
    "  ### SNAKES AND LADDERS GAME MODELLING WITH MARKOV CHAINS ### ",
    150 * "_",
    sep="\n",
)

# Summary of the parameters defined above.
print(
    " ",
    "   - Parameters:",
    "      > Length of the game: {}  - Number of positions:  {}".format(
        params['Game_Length'], params['Game_Length'] ** 2
    ),
    "      > Number of players: {}".format(params['Players']),
    "      > Number of snakes: {}".format(params['Snakes_Ladders']),
    "      > Number of ladders: {}".format(params['Snakes_Ladders']),
    150 * "_",
    sep="\n",
)

def Game(d, S):
    """Generates a game set for Snakes and Ladders game and saves it to disk.

    Squares are numbered from 1 to d ** 2. Snakes and ladders are drawn at
    random until all their end points are distinct squares and none of them
    is the first square (1) or the last square (d ** 2).

    Inputs:
    - d: Int - Length of the game set. The game set has d x d positions.
    - S: Int - Number of snakes and ladders. The game has S snakes and S ladders.

    Saves [file names carry the parameters of the game: d ** 2 and S]:
    - Snakes_Ladders: (2, 2 * S) integer array. Row 0 holds the departure
    squares and row 1 the arrival squares; columns 0 .. S-1 are snakes
    (arrival on a lower row), columns S .. 2*S-1 are ladders (arrival on a
    higher row).
    - Iteration vector: entry k-1 is the square a player ends on after landing
    on square k (k itself unless k is the departure of a snake or a ladder).
    - Probability transition matrix: entry [j-1, k-1] is the probability of
    going from square k to square j in one turn [Markov chain]. A throw that
    would pass the last square stops on the last square.

    Raises:
    - ValueError: if S is negative, or if the board has fewer than 4 * S
    squares besides the first and the last one (the random search below
    could then never succeed and would loop forever)."""

    import os  # local import: only needed to create the output folder

    n_squares = d ** 2

    if S < 0:
        raise ValueError("S must be non-negative, got " + str(S))
    # Each snake/ladder needs 2 distinct squares taken from 2 .. d**2 - 1.
    if S > 0 and 4 * S > n_squares - 2:
        raise ValueError(
            "Cannot place " + str(S) + " snakes and " + str(S)
            + " ladders on a board of " + str(n_squares) + " squares"
        )

    print(150 * "_")
    print(" ")
    print("   - Generation of a Game set")
    print(150 * "_")
    print(" ")

    Snakes_Ladders = np.zeros((2, 2 * S), dtype=np.int64)
    It_Vect = np.zeros(n_squares)
    Prob_Matrix = np.zeros((n_squares, n_squares))

    # Generation of Snakes and Ladders.
    # The all-zero start is invalid whenever S > 0, so at least one draw is
    # made; the whole set is redrawn until it is valid.
    print("   - Generation of Snakes and Ladders")
    while (np.unique(Snakes_Ladders).shape[0] < 4 * S
           or 1 in Snakes_Ladders
           or n_squares in Snakes_Ladders):
        # Generation of Snakes: departure on row i >= 1, arrival on a row below.
        for s in range(S):
            print("      > Snakes: " + str(s + 1) + " / " + str(S), end="   \r")
            i, j = np.random.randint(1, d), np.random.randint(0, d)
            a = d * i + j + 1
            ii, jj = np.random.randint(0, i), np.random.randint(0, d)
            b = d * ii + jj + 1
            Snakes_Ladders[:, s] = np.array([a, b])

        # Generation of Ladders: departure on row i <= d-2, arrival on a row above.
        for s in range(S):
            print("      > Ladders: " + str(s + 1) + " / " + str(S), end="   \r")
            i, j = np.random.randint(0, d - 1), np.random.randint(0, d)
            a = d * i + j + 1
            ii, jj = np.random.randint(i + 1, d), np.random.randint(0, d)
            b = d * ii + jj + 1
            Snakes_Ladders[:, S + s] = np.array([a, b])

    # Generation of Iteration vector.
    # All end points are distinct, so every departure square maps to exactly
    # one arrival square and arrival squares are never departures.
    print("   - Generation of Iteration Matrix")
    jumps = dict(zip(Snakes_Ladders[0].tolist(), Snakes_Ladders[1].tolist()))
    for k in range(1, n_squares + 1):
        print("      > Number: " + str(k) + " / " + str(n_squares), end="   \r")
        It_Vect[k - 1] = jumps.get(k, k)

    # Generation of Probability transition matrix
    print("   - Generation of Probability Transition Matrix")
    for k in range(1, n_squares + 1):
        for s in range(1, 7):
            # Dice throw (clamped to the last square), then snake/ladder.
            landing = min(k + s, n_squares)
            Prob_Matrix[int(It_Vect[landing - 1]) - 1, k - 1] += 1 / 6

    prefix = r"Documents_Snakes_and_Ladders\Game_Snakes_and_Ladders_size=" + str(n_squares) + "_Snakes_Ladders=" + str(S)
    name_Snakes_Ladders = prefix + "_Snakes_and_Ladders.npy"
    name_It_Vector = prefix + "_Iteration_Vector.npy"
    name_Prob_Matrix = prefix + "_Prob_Matrix.npy"

    # Where the backslash is a path separator the files live in a sub-folder:
    # create it if needed. Elsewhere the directory part is empty and the
    # files are written to the current directory under the same names.
    folder = os.path.dirname(prefix)
    if folder:
        os.makedirs(folder, exist_ok=True)

    np.save(name_Snakes_Ladders, Snakes_Ladders)
    np.save(name_It_Vector, It_Vect)
    np.save(name_Prob_Matrix, Prob_Matrix)

    print(150 * "_")

    return None

def Run(d, S, N):
    """Runs a game with selected parameters.

    Loads the iteration vector and the probability transition matrix saved
    for a board with the same d and S, then plays turns until every player
    stands on the last square. Each turn, every player throws one dice
    [1 -> 6] (a throw passing the last square stops on it) and then follows
    the snake or ladder of the square reached, if any. The theoretical
    distribution is advanced by one multiplication with the transition matrix
    per turn. A figure comparing the finished fraction (theory vs players) is
    shown before the results are saved.

    Inputs:
    - d: Int - Game Length. Total size: d ** 2.
    - S: Int - Number of Snakes and Ladders (S snakes and S ladders).
    - N: Int - Number of players.

    Saves [file names carry the parameters of the game: d ** 2, S and N]:
    - Players positions: (N, 1 + 2 * T) array for T turns played. Column 0 is
    the start; each turn adds one column after the dice throw and one column
    after snakes and ladders.
    - Theoretical distribution: (d ** 2, 1 + T) array, one column per turn
    [deduced with Markov chains].
    - Finished players: (1 + T,) array with the number of players on the last
    square after each turn.

    Raises:
    - ValueError: if N is smaller than 1."""

    if N < 1:
        raise ValueError("Number of players N must be at least 1, got " + str(N))

    print(150 * "_")
    print(" ")
    print("   - Game running")
    print(150 * "_")
    print(" ")

    n_squares = d ** 2
    prefix = r"Documents_Snakes_and_Ladders\Game_Snakes_and_Ladders_size=" + str(n_squares) + "_Snakes_Ladders=" + str(S)
    name_It_Vector = prefix + "_Iteration_Vector.npy"
    name_Prob_Matrix = prefix + "_Prob_Matrix.npy"

    # Integer copy of the iteration vector so it can be used for fancy indexing.
    Jump = np.load(name_It_Vector).astype(np.int64)
    Prob_Matrix = np.load(name_Prob_Matrix)

    # Initialization for players: everybody starts on square 1
    Pos = np.ones((N, 1), dtype=np.int64)

    # Initialization for theoretical game: all the mass on square 1
    Dist = np.zeros((n_squares, 1))
    Dist[0, 0] = 1.0

    # Histories are collected in lists and joined once at the end, instead of
    # re-allocating a growing array every turn. Every step below creates a new
    # array, so the stored frames are never modified afterwards.
    pos_frames = [Pos]
    dist_frames = [Dist]
    finished_counts = [0]

    # Game run
    while np.min(Pos) < n_squares:
        # Effects of dices [1 -> 6], clamped to the last square
        Pos = np.minimum(Pos + np.random.randint(low=1, high=7, size=(N, 1), dtype=np.int64), n_squares)
        pos_frames.append(Pos)
        # Effects of Snakes and Ladders (one vectorized lookup for all players)
        Pos = Jump[Pos - 1]
        pos_frames.append(Pos)
        finished = int(np.count_nonzero(Pos > n_squares - 1))
        finished_counts.append(finished)

        # Update of theoretical distribution
        Dist = Prob_Matrix @ Dist
        dist_frames.append(Dist)

        print("      > Finished players: " + str(int(1000 * finished / N) / 10) + " %  -  Finished players [Theory]: " + str(int(1000 * Dist[-1, 0]) / 10) + " %  ", end="      \r")

    Pos_Time = np.concatenate(pos_frames, axis=1)
    Dist_Time = np.concatenate(dist_frames, axis=1)
    Finished_Players = np.array(finished_counts)

    plt.figure()
    plt.plot(np.arange(0, Dist_Time.shape[1]), Dist_Time[-1, :], color = "red", label = "Theory")
    plt.plot(np.arange(0, Finished_Players.size), Finished_Players / N, color = "green", label = "Players")
    plt.legend()
    plt.grid()
    plt.show()

    name_Theory_Distribution = prefix + "_Players=" + str(N) + "_Theory_distribution.npy"
    name_Players_Positions = prefix + "_Players=" + str(N) + "_Players_Positions.npy"
    name_Finished_Players = prefix + "_Players=" + str(N) + "_Finished_Players.npy"
    np.save(name_Theory_Distribution, Dist_Time)
    np.save(name_Players_Positions, Pos_Time)
    np.save(name_Finished_Players, Finished_Players)

    print(150 * "_")
    return None

def Print(d, S, N, save = False):
    """Prints the evolution of a Game with specified parameters.

    Loads the board and the run results saved for the same d, S and N and
    builds an animation with one frame per saved column of players positions.
    Each frame shows three panels: the board with snakes (red arrows), ladders
    (green arrows) and players; the percentage of finished players over the
    turns (theory vs players); and the distribution of positions (theory vs
    players).

    Inputs:
    - d: Int - Game Length. Total size: d ** 2.
    - S: Int - Number of Snakes and Ladders (S snakes and S ladders).
    - N: Int - Number of players.
    - save: Boolean - If true, the animation is written to a GIF file in the
    current directory; otherwise it is shown in a window.
    """

    print(150 * "_")
    print(" ")
    print("   - Game plot")
    print(150 * "_")
    print(" ")

    n_squares = d ** 2

    # Load arrays
    prefix_board = r"Documents_Snakes_and_Ladders\Game_Snakes_and_Ladders_size=" + str(n_squares) + "_Snakes_Ladders=" + str(S)
    prefix_run = prefix_board + "_Players=" + str(N)

    Snakes_Ladders = np.load(prefix_board + "_Snakes_and_Ladders.npy")
    Dist_Time = np.load(prefix_run + "_Theory_distribution.npy")
    Pos_Time = np.load(prefix_run + "_Players_Positions.npy")
    Finished_Players = np.load(prefix_run + "_Finished_Players.npy")

    # Parameters extraction
    N_Frames = Pos_Time.shape[1]

    # Random offset of each player inside a square, fixed for the whole
    # animation so that players do not jitter between frames.
    Coord_Players = np.random.uniform(low=-0.4, high=0.4, size=(2, N))

    # Coordinate function
    def Coordinate(k, d):
        """Gives coordinate (kx,ky) in a Game set for
        a game of length d.

        Rows are numbered from the bottom; even rows run left to right and
        odd rows right to left.

        Inputs:
        - k: Int - Number for coordinate computation.
        - d: Int. Length of the game."""

        iy, ix = divmod(k - 1, d)
        jx = ix if iy % 2 == 0 else d - ix - 1
        return jx, iy

    # Arrows of snakes (first S columns) and ladders (last S columns) do not
    # change between frames: compute their end points once.
    arrows = []
    for s in range(2 * S):
        start = Coordinate(int(Snakes_Ladders[0, s]), d)
        end = Coordinate(int(Snakes_Ladders[1, s]), d)
        arrows.append((start, end, "red" if s < S else "green"))

    squares = np.arange(1, n_squares + 1)
    # One unit-wide bin centred on every square, so that the histogram lines
    # up with the theoretical curve and the last two squares are not merged.
    bin_edges = np.arange(0.5, n_squares + 1.5)

    # Figure plot
    fig = plt.figure(figsize=(12, 3))
    def animation_func(n):
        print("      > Frame: " + str(n + 1) + " / " + str(N_Frames), end = "     \r")

        # Positions are saved twice per turn (after the dice, after snakes and
        # ladders) while theory and finished counts are saved once per turn.
        turn = n // 2

        plt.clf()

        # Panel 1: board
        plt.subplot(1, 3, 1)
        for g in range(d + 1):
            plt.plot([0, d], [g, g], color = "black")
            plt.plot([g, g], [0, d], color = "black")

        for (ix, iy), (jx, jy), color in arrows:
            plt.arrow(ix + 0.5, iy + 0.5, jx-ix, jy-iy, width=0.05, color=color)

        # One scatter call per player so that each player gets its own colour
        # from the default colour cycle.
        for nn in range(N):
            ix, iy = Coordinate(int(Pos_Time[nn, n]), d)
            ix, iy = ix + 0.5 + Coord_Players[0, nn], iy + 0.5 + Coord_Players[1, nn]
            plt.scatter([ix], [iy], marker="o", s=10)
        ax = plt.gca()
        ax.set_aspect('equal')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        plt.xlim(-0.5, d + 0.5)
        plt.ylim(-0.5, d + 0.5)
        plt.title("Game evolution")

        # Panel 2: finished players up to and including the current turn
        plt.subplot(1, 3, 2)
        turns = np.arange(0, turn + 1)
        plt.plot(turns, 100 * Dist_Time[-1, :turn + 1], color="red", label="Theory")
        plt.plot(turns, 100 * Finished_Players[:turn + 1] / N, color="green", label="Players")
        plt.title("Finished players [%]")
        plt.legend(loc="upper left")
        plt.grid()

        # Panel 3: distribution of positions
        plt.subplot(1, 3, 3)
        plt.plot(squares, Dist_Time[:, turn], color="red", label = "Theory")
        plt.hist(Pos_Time[:, n], bins=bin_edges, density=True, color="green", label="Players")
        plt.grid()
        plt.legend(loc="upper left")
        plt.title("Distribution")

    # Keep a reference to the animation object until it has been saved/shown.
    anim = FuncAnimation(fig, animation_func, interval=100, blit=False, repeat=True, frames=N_Frames)
    if save:
        anim.save("Snakes_and_Ladders_Length=" + str(d) + "_Snakes_Ladders=" + str(S) + "_Players=" + str(N) + ".gif", writer="pillow")
    else:
        fig.tight_layout()
        plt.show()

    print(150 * "_")
    return None

