import os

from numpy import *
import numpy.linalg as la
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib as mpl
# Path to the ffmpeg executable used by FFMpegWriter. It is machine-specific,
# so it is only applied when the file actually exists; otherwise matplotlib
# keeps its default setting and looks for ``ffmpeg`` on the PATH.
# Single backslashes on purpose: inside a raw string ``\\`` stays doubled.
FFMPEG_PATH = r'C:\CompiledPrograms\ffmpeg-20190507-e25bddf-win64-static\bin\ffmpeg.exe'
if os.path.isfile(FFMPEG_PATH):
    mpl.rcParams['animation.ffmpeg_path'] = FFMPEG_PATH

# Current matplotlib colour cycle, used to give each integrator its own colour.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

AU = 149597870.700e3  # (m / au) metres in one astronomical unit
DAY = 3600 * 24  # (s / day) seconds in one day
AU_PER_DAY = AU / DAY  # (m / s) / (au / day) = (m / au) / (s / day)

# Body masses (kg); the values match the Sun, Jupiter and Saturn
m = array([
    1.989e30,
    1.898e27,
    5.683e26,
], dtype=float64)
# Number of bodies
n = len(m)
# Gravitational constant (m^3 kg^-1 s^-2)
G = 6.67408e-11
# Initial positions [[x0, y0], [x1, y1], ...] on 17 April, given relative to
# the first body (which sits at the origin) (m)
q0 = array([
    [0, 0],
    [3.625838752290225E+00, -3.517639747097937E+00],
    [5.935311726536074E+00, -7.998765340509357E+00],
], dtype=float64) * AU
# Initial velocities on 17 April, relative to the first body (m/s)
v0 = array([
    [0, 0],
    [5.161600958019233E-03, 5.773107422075889E-03],
    [4.169003257611070E-03, 3.310760061243794E-03],
], dtype=float64) * AU_PER_DAY
# SOURCE: https://ssd.jpl.nasa.gov/horizons.cgi

# With the first body at rest, the system's total momentum is not zero. Its
# centre of mass therefore moves at constant speed (roughly 15 m/s here), and
# the whole system drifts more than 1 AU over a few centuries, out of a plot
# window centred on the origin. Move to the barycentric frame: centre of
# mass at the origin and at rest. Relative positions/velocities are unchanged.
q0 -= (m[:, None] * q0).sum(axis=0) / m.sum()
v0 -= (m[:, None] * v0).sum(axis=0) / m.sum()

# Initial momenta p = m v (kg m/s); m[:, None] scales each body's row
p0 = v0 * m[:, None]

def dHdq(q, p):
    """Return the gradient of the Hamiltonian with respect to the positions.

    Row i is ``G * sum_{j != i} m_i m_j (q_i - q_j) / |q_i - q_j|**3``, the
    gradient of the gravitational potential on body i. ``p`` is not used.
    """
    def pair_term(i, j):
        # Contribution of body j to the gradient on body i (G factored out)
        delta = q[i] - q[j]
        return m[i] * m[j] * delta / la.norm(delta) ** 3

    per_body = []
    for i in range(n):
        others = [pair_term(i, j) for j in range(n) if j != i]
        per_body.append(sum(array(others), axis=0))
    return G * array(per_body)

def dHdp(q, p):
    """Return the gradient of the Hamiltonian with respect to the momenta.

    For the kinetic term ``sum p**2 / (2 m)`` this is ``p / m``, i.e. each
    body's velocity. ``q`` is not used.
    """
    # Masses as a column so that each row of p (one body) is divided by that
    # body's mass rather than dividing column-wise.
    column_masses = m.reshape(-1, 1)
    return p / column_masses

def euler_symplectique(dt, q, p):
    """Advance (q, p) by one step of the symplectic Euler method.

    The momentum is updated first with the force at the current positions,
    then the positions are updated using that new momentum.

    Returns new ``(q, p)`` arrays; the caller's arrays are not modified.
    """
    # Out-of-place updates: no hidden mutation of the caller's state, and
    # integer-typed inputs are upcast instead of raising a casting error.
    p = p - dt * dHdq(q, p)
    q = q + dt * dHdp(q, p)
    return q, p

def heun(dt, q, p):
    """Advance (q, p) by one step of Heun's method (explicit trapezoidal rule).

    An explicit Euler predictor gives an intermediate state; the step then
    uses the average of the derivatives at the start and at that predicted
    state.

    Returns new ``(q, p)`` arrays; the inputs are not modified.
    """
    # Derivatives at the start of the step: evaluated once and reused by both
    # the predictor and the corrector (dHdq is the costly pairwise part).
    dq_start = dHdp(q, p)
    dp_start = dHdq(q, p)
    # Predictor: intermediate explicit-Euler point
    q2 = q + dt * dq_start
    p2 = p - dt * dp_start
    # Corrector: average of the start and predicted derivatives
    return (
        q + dt/2 * (dq_start + dHdp(q2, p2)),
        p - dt/2 * (dp_start + dHdq(q2, p2)),
    )

def stormer_verlet(dt, q, p):
    """Advance (q, p) by one Störmer-Verlet (kick-drift-kick) step.

    Half-step momentum update, full-step position update with that half-step
    momentum, then a second half-step momentum update at the new positions.

    Returns new ``(q, p)`` arrays; the caller's arrays are not modified.
    """
    p_half = p - dt/2 * dHdq(q, p)
    q_new = q + dt * dHdp(q, p_half)
    p_new = p_half - dt/2 * dHdq(q_new, p_half)
    return q_new, p_new

# Integration schemes compared side by side. The list index of each scheme
# selects its plot colour and its row in the energy record.
INTEGRATORS = [heun, euler_symplectique, stormer_verlet]

def energy(q, p):
    """Return the total energy (value of the Hamiltonian) of the system.

    Kinetic part: sum over all bodies and components of p**2 / (2 m).
    Potential part: -G m_a m_b / |q_a - q_b| for every pair a < b.
    """
    kinetic = sum(1/2 * p**2 / m[:, None])
    pair_potentials = (
        G * m[a] * m[b] / la.norm(q[a] - q[b])
        for a in range(n)
        for b in range(a + 1, n)
    )
    total = kinetic
    for term in pair_potentials:
        total = total - term
    return total

# Integration time step (s): 30 days
dt = DAY * 30
# One independent copy of the initial state per integrator, so each scheme
# evolves its own trajectory.
q = [q0.copy() for _ in range(len(INTEGRATORS))]
p = [p0.copy() for _ in range(len(INTEGRATORS))]

fig = plt.figure()

# Half-width of the square plot window (m). The outermost body starts about
# 9.96 AU from the origin and its slightly eccentric orbit reaches ~10.1 AU,
# so a 10 AU window clipped it; keep a margin.
PLOT_SIZE = 11 * AU
# Total simulated time (s): 400 years of 365 days
TEMPS_TOTAL = DAY * 365 * 400
# Number of animation frames; each frame advances every integrator by dt
FRAMES = int(TEMPS_TOTAL / dt)

# Time of each frame, in years
time_values = dt * arange(FRAMES, dtype=float64) / DAY / 365
# Total energy per integrator (rows) and per frame (columns)
energy_values = zeros((len(INTEGRATORS), FRAMES))


def frame(i):
    """Draw animation frame ``i`` and advance every integrator by one step.

    For each integrator: plot the bodies' current positions and velocity
    arrows, record the current total energy in ``energy_values[j, i]``, then
    store the state obtained after one integration step of ``dt``.
    """
    fig.clear()
    for j, integrator in enumerate(INTEGRATORS):
        q_cur = q[j]
        p_cur = p[j]
        # Current positions, plus velocity arrows (v = p / m)
        plt.scatter(q_cur[:, 0], q_cur[:, 1], color=colors[j], label=integrator.__name__)
        plt.quiver(q_cur[:, 0], q_cur[:, 1], p_cur[:, 0] / m, p_cur[:, 1] / m, color=colors[j])
        # Energy of the state shown in this frame, for the final energy plot
        energy_values[j, i] = energy(q_cur, p_cur)
        # Advance one step and keep the new state for the next frame
        q[j], p[j] = integrator(dt, q_cur, p_cur)

    plt.xlim(-PLOT_SIZE, PLOT_SIZE)
    plt.ylim(-PLOT_SIZE, PLOT_SIZE)
    # Same scale on both axes: the figure (default size) is not square, so
    # without this circular orbits are drawn as ellipses.
    plt.gca().set_aspect('equal')
    plt.text(0.05, 0.05, "%8.2f ans" % time_values[i], transform=plt.gca().transAxes)
    plt.legend()

    # Progress report, because the whole rendering takes a while
    if i % 25 == 0:
        print("Calcul: %.2f %%" % (i / FRAMES * 100))


# Render the animation to a video file, integrating while drawing.
fps = 50


def _init_frame():
    """Initial draw for FuncAnimation that does not advance the simulation.

    Without an ``init_func``, the initial draw performed by ``Animation.save``
    calls ``frame(0)``; the frame loop then calls ``frame(0)`` again. The first
    integration step would therefore run twice, and ``energy_values[:, 0]``
    would hold the energy after one step instead of the initial energy.
    """
    return []


anim = animation.FuncAnimation(fig, frame, frames=FRAMES, init_func=_init_frame)
animWriter = animation.FFMpegWriter(fps=fps)
anim.save("animation.mp4", writer=animWriter)

# Final figure: each integrator's total energy divided by its first recorded
# value, as a function of time in years.
plt.clf()
for j, integrator in enumerate(INTEGRATORS):
    relative_energy = energy_values[j] / energy_values[j][0]
    plt.plot(time_values, relative_energy, label=integrator.__name__)
plt.title("Energie totale relative en fonction du temps")
plt.legend()
plt.tight_layout()
plt.savefig("energie.pdf")