Le désavantage de la forward differentiation, c'est qu'il faut recommencer tout le calcul pour calculer la dérivée par rapport à chaque variable. La *reverse differentiation*, aussi appelée *backpropagation*, résoud se problème en calculer la dérivée par rapport à toutes les variables en une fois!
### Chain rule
#### Exemple univarié
Commençons par un exemple univarié pour introduire le fait qu'il existe un choix dans l'ordre de la multiplication des dérivées. La liberté introduite par ce choix donne lieu à la différence entre la différentiation *forward* et *reverse*.
Supposions qu'on veuille dériver la fonction ``\tan(\cos(\sin(x)))`` pour ``x = \pi/3``. La Chain Rule nous donne:
Vous remarquerez que dans l'équation ci-dessus, comme mis en évidence en rouge, les valeurs auxquelles les dérivées doivent être évaluées dépendent de ``\sin(\pi/3)``.
L'approche utilisée par reverse diff de multiplier de gauche à droite ne peut donc pas être effectuer sans prendre en compte la valeur qui doit être évaluée de droite à gauche.
Pour appliquer reverse diff, il faut donc commencer par une *forward pass* de droite à gauche qui calcule ``\sin(\pi/3)`` puis ``\cos(\sin(\pi/3))`` puis ``\tan(\cos(\sin(\pi/3)))``. On peut ensuite faire la *backward pass* qui multipliée les dérivée de gauche à droite. Afin d'être disponibles pour la backward pass, les valeurs calculées lors de la forward pass doivent être **stockées** ce qui implique un **coût mémoire**.
En revanche, comme forward diff calcule la dérivée dans le même sens que l'évaluation, les dérivées et évaluations peuvent être calculées en même temps afin de ne pas avoir besoin de stocker les évaluations. C'est effectivement ce qu'on a implémenter avec `Dual` précédemment.
Au vu de ce coût mémoire supplémentaire de reverse diff par rapport à forward diff,
ce dernier paraît préférable en pratique.
On va voir maintenant que dans le cas multivarié, dans certains cas, ce désavantage est contrebalancé par une meilleure complexité temporelle qui rend reverse diff indispensable!
#### Exemple multivarié
Prenons maintenant un example multivarié, supposons qu'on veuille calculer le gradient de la fonction ``f(g(h(x_1, x_2)))`` qui compose 3 fonctions ``f``, ``g`` et ``h``.
Le gradient est obtenu via la chain rule comme suit:
Pour calculer ``\partial f / \partial x_1`` via forward diff, on part donc de ``\partial x_1 / \partial x_1 = 1`` et ``\partial x_2 / \partial x_1 = 0`` et on calcule ensuite ``\partial h / \partial x_1``, ``\partial g / \partial x_1`` puis ``\partial f / \partial x_1``.
Effectuer la reverse diff est un peu moins intuitif. L'idée est de partir de la dérivée du résultat par rapport à lui même ``\partial f / \partial f = 1`` et de calculer ``\partial f / \partial g`` puis ``\partial f / \partial h`` et ensuite ``\partial f / \partial x_1``. L'avantage de reverse diff c'est qu'il n'y a que la dernière étape qui est sécifique à ``x_1``. Tout jusqu'au calcul de ``\partial f / \partial h`` peut être réutilisé pour calculer ``\partial f / \partial x_2``, il n'y a plus cas multiplier ! Reverse diff est donc plus efficace pour calculer le gradient d'une fonction qui a une seul output par rapport à beaucoup de paramètres comme détaillé dans la discussion à la fin de ce notebook.
"""
# ╔═╡ 56b32132-113f-459f-b1d9-abb8f439a40b
md"""
### Forward pass : Construction de l'expression graph
Pour implémenter reverse diff, il faut construire l'expression graph pour garder en mémoire les valeurs des différentes expressions intermédiaires afin de pouvoir calculer les dérivées locales ``\partial f / \partial g`` et ``\partial g / \partial h``. Le code suivant défini un noeud de l'expression graph. Le field `derivative` correspond à la valeur de ``\partial f_{\text{final}} / \partial f_{\text{node}}`` où ``f_\text{final}`` est la dernière fonction de la composition et ``f_{\text{node}}`` est la fonction correspondant au node.
L'operateur overloading suivant sera sufficant pour construire l'expression graph dans le cadre de ce notebook, vous l'étendrez pendant la séance d'exercice.
md"Note: Une amélioration sera vue avec le tri topologique (voir théorie des graphes)"
# ╔═╡ 1572f901-c688-435e-81b9-d6e39bb82201
md"""
On crée les leafs correspondant aux variables ``x_1`` et ``x_2`` de valeur ``1`` et ``2`` respectivement. Les valeurs correspondent aux valeurs de ``x_1`` et ``x_2`` auxquelles on veut dériver la fonction. On a choisi 1 et 2 pour pouvoir les reconnaitre facilement dans le graphe.
"""
# ╔═╡ 86872f35-d62d-40e5-8770-4585d3b0c0d7
function topo_sort!(visited,topo,f::Node)
if!(finvisited)
push!(visited,f)
forarginf.args
topo_sort!(visited,topo,arg)
end
push!(topo,f)
# ╔═╡ b7faa3b7-e0b6-4f55-8763-035d8fc5ac93
x_nodes=Node.([1,2])
# ╔═╡ 194d3a68-6bed-41d9-b3ea-8cfaf4787c54
expr=cos(sin(prod(x_nodes)))
# ╔═╡ 7285c7e8-bce0-42f0-a53b-562a8a6c5894
function _nodes!(nodes,x::Node)
if!(xinkeys(nodes))
nodes[x]=length(nodes)+1
end
forarginx.args
_nodes!(nodes,arg)
end
end
# ╔═╡ b9f0e9b6-c111-4d69-990f-02c460c8706d
function _edges(g,labels,nodes::Dict{Node},done,x::Node)
id=nodes[x]
ifdone[id]
return
end
done[id]=true
ifisnothing(x.op)
labels[id]=string(x.value)
else
labels[id]="["*String(x.op)*"] "*string(x.value)
end
forarginx.args
add_edge!(g,id,nodes[arg])
_edges(g,labels,nodes,done,arg)
end
end
# ╔═╡ 114c048f-f619-4e15-8e5a-de852b0a1861
function graph(x::Node)
nodes=Dict{Node,Int}()
_nodes!(nodes,x)
done=falses(length(nodes))
g=DiGraph(length(nodes))
labels=Vector{String}(undef,length(nodes))
_edges(g,labels,nodes,done,x)
returng,labels
end
# ╔═╡ 7578be43-8dbe-4041-adc7-275f06057bfe
md"""
Pour le visualiser, on le converti en graphe utilisant la structure de donnée de Graphs.jl pour pouvoir utiliser `gplot`
"""
# ╔═╡ 69298293-c9fc-432f-9c3c-5da7ce710334
expr_graph,labels=graph(expr)
# ╔═╡ 9b9e5fc3-a8d0-4c42-91ae-25dd01bc7d7e
gplot(expr_graph,nodelabel=labels)
# ╔═╡ bd705ddd-0d00-41a0-aa55-e82daad4133d
md"### Backward pass : Calcul des dérivées"
# ╔═╡ d8052188-f2fa-4ad8-935f-581eea164bda
md"""
La fonction suivante propage la dérivée ``\partial f_{\text{final}} / \partial f_{\text{node}}`` à la dérivée des arguments de la fonction ``f_{\text{node}}``.
Comme les arguments peuvent être utilisés à par d'autres fonction, on somme la dérivée avec `+=`.
"""
# ╔═╡ 1e08b49d-03fe-4fb3-a8ba-3a00e1374b32
function _backward!(f::Node)
ifisnothing(f.op)
...
...
@@ -301,11 +529,32 @@ function _backward!(f::Node)
La fonction `_backward!` ne doit être appelée que sur un noeud pour lequel `f.derivative` a déjà été calculé. Pour cela, `_backward!` doit avoir été appelé sur tous les noeuds qui représente des fonctions qui dépendent directement ou indirectement du résultat du noeud.
Pour trouver l'ordre dans lequel appeler `_backward!`, on utilise donc on tri topologique (nous reviendrons sur les tris topologique dans la partie graphe).
"""
# ╔═╡ 86872f35-d62d-40e5-8770-4585d3b0c0d7
function topo_sort!(visited,topo,f::Node)
if!(finvisited)
push!(visited,f)
forarginf.args
topo_sort!(visited,topo,arg)
end
push!(topo,f)
end
end
# ╔═╡ 26c40cf4-9585-4762-abf4-ff77342a389f
function backward!(f::Node)
topo=typeof(f)[]
...
...
@@ -321,6 +570,22 @@ function backward!(f::Node)
returnf
end
# ╔═╡ 86fc7924-2002-4ac6-8e02-d9bf5edde9bf
backward!(expr)
# ╔═╡ d72d9c99-6280-49a2-9f7a-e9628f1069eb
md"On a maintenant l'information sur les dérivées de `x_nodes`:"
# ╔═╡ 0649437a-4198-4556-97dc-1b5cfbe45eed
x_nodes
# ╔═╡ 3928e9f7-9539-4d99-ac5b-6336eff8a523
md"""
### Comparaison avec Forward Diff dans l'exemple moon
Revenons sur l'exemple utilisé pour illustrer la forward diff et essayons de calculer la même dérivée mais à présent en utiliser reverse diff.
"""
# ╔═╡ 5ce7fbad-af38-4ff6-adca-b1991f3be455
w_nodes=Node.(w)
...
...
@@ -335,25 +600,34 @@ function reverse_diff(loss, w, X, y)
return[w.derivativeforwinw_nodes]
end
# ╔═╡ a610dc3c-803a-4489-a84b-8bff415bc0a6
md"We execute it a second time to get rid of the compilation time:"
# ╔═╡ 0e99048d-5696-43ab-8896-301f37a20a5d
md"On remarque que reverse diff est plus lent! IL y a un certain coût mémoire lorsqu'on consruit l'expression graph. Pour cette raison, si on veut calculer plusieurs dérivées consécutives pour différentes valeurs de ``x_1`` et ``x_2``, on a intérêt à garder le graphe et à uniquement changer la valeur des variables plutôt qu'à reconstruire le graphe à chaque fois qu'on change les valeurs. Alternativement, on peut essayer de condenser le graphe en exprimant les opérations sur des large matrices ou même tenseurs, c'est l'approche utilisée par pytorch ou tensorflow."
We see that we should find the minimum ``d_k`` and start from there. If the minimum is attained at ``k = n``, this corresponds mutliplying from left to right, this is reverse differentiation. If the minimum is attained at ``k = 0``, we should multiply from right to left, this is forward differentiation. Otherwise, we should start from the middle, this would mean mixing both forward and reverse diff.
What about neural networks ? In that case, ``d_0`` is equal to the number of entries in ``W_1`` added with the number of entries in ``W_2`` while ``d_n`` is ``1`` since the loss is scalar. We should therefore clearly multiply from left to right hence do reverse diff.
"""
# ╔═╡ fdd28672-5902-474a-8c87-3f6f38bcf54f
md"""
...
...
@@ -530,6 +894,9 @@ plot_w(w_trained_L1)
# ╔═╡ 4444622b-ccfe-4867-b659-489573099f1e
@timereverse_diff(mse,w,X,y)
# ╔═╡ c1942ee2-c5af-4b2b-986d-6ad563ef27bb
@timereverse_diff(mse,w,X,y)
# ╔═╡ 0af6bce6-bc3b-438e-a57b-0c0c6586c0c5
W1=rand(size(X,2),num_hidden)
...
...
@@ -567,6 +934,8 @@ cross_entropies = -log.(sum(Y_est .* Y, dims=2))