### A Pluto.jl notebook ###
# v0.19.47
# v0.20.0
using Markdown
using InteractiveUtils
......@@ -14,6 +14,9 @@ macro bind(def, element)
end
end
# ╔═╡ 9f027cde-dba0-4da5-8c42-5fa79b3929d6
using Graphs, GraphPlot
# ╔═╡ f1ba3d3c-d0a5-4290-ab73-9ce34bd5e5f6
using Plots, OneHotArrays, PlutoUI
......@@ -175,10 +178,10 @@ md"`num_iters` = $(@bind num_iters Slider(1:20, default = 10, show_value = true)
md"## Kernel trick"
# ╔═╡ f71923d4-fbc9-4ce6-b5be-a00437c3651d
md"`η_lift` = $(@bind η_lift Slider(exp10.(-4:0.25:1), default=1, show_value = true))"
md"`η_lift` = $(@bind η_lift Slider(exp10.(-4:0.25:1), default=0.01, show_value = true))"
# ╔═╡ 18edd949-ce86-431a-a19b-4daf526e57a6
md"`num_iters_lift` = $(@bind num_iters_lift Slider(1:200, default=10, show_value = true))"
md"`num_iters_lift` = $(@bind num_iters_lift Slider(1:400, default=200, show_value = true))"
# ╔═╡ dc4feb58-d2cf-4a97-aaed-7f4593fc9732
md"""
......@@ -192,42 +195,211 @@ Par contre, n'importe quel nombre entre ``-1`` et ``1`` est un **subgradient** v
"""
# ╔═╡ f5749121-8e75-45de-95b9-63fff584e350
md"`η_L1` = $(@bind η_L1 Slider(exp10.(-4:0.25:1), default=1, show_value = true))"
md"`η_L1` = $(@bind η_L1 Slider(exp10.(-4:0.25:1), default=0.1, show_value = true))"
# ╔═╡ 66e36fb8-5a61-49a7-8053-911fd887b0a9
md"`num_iters_L1` = $(@bind num_iters_L1 Slider(1:200, default=10, show_value = true))"
md"`num_iters_L1` = $(@bind num_iters_L1 Slider(1:400, default=200, show_value = true))"
# ╔═╡ 613269ef-16ba-44ef-ad8e-997cc9aec1fb
# ╔═╡ db28bb45-3418-4080-a0fc-9136fc0196a5
md"""
## Reverse diff
The drawback of forward differentiation is that the whole computation must be redone to obtain the derivative with respect to each variable. *Reverse differentiation*, also called *backpropagation*, solves this problem by computing the derivative with respect to all variables at once!
### Chain rule
#### Univariate example
Let us start with a univariate example to show that there is a choice in the order in which the derivatives are multiplied. The freedom introduced by this choice gives rise to the difference between *forward* and *reverse* differentiation.
Suppose we want to differentiate the function ``\tan(\cos(\sin(x)))`` at ``x = \pi/3``. The chain rule gives:
```math
\begin{align}
(\tan(\cos(\sin(x))))'
& = \left. (\tan(x))' \right|_{x = \cos(\sin(x))} (\cos(\sin(x)))'\\
& = \left. (\tan(x))' \right|_{x = \cos(\sin(x))}
\left. (\cos(x))' \right|_{x = \sin(x)}
(\sin(x))'\\
& = \frac{1}{\cos^2(\cos(\sin(x)))} (-\sin(\sin(x))) \cos(x)
\end{align}
```
The derivative at ``x = \pi/3`` is therefore:
```math
\begin{align}
\left. (\tan(\cos(\sin(x))))' \right|_{x = \pi/3}
& =
\left. (\tan(x))' \right|_{x = \cos(\sin(\pi/3))}
\left. (\cos(x))' \right|_{x = \sin(\pi/3)}
\left. (\sin(x))' \right|_{x = \pi/3}\\
& = \frac{1}{\cos^2(\cos(\sin(\pi/3)))} (-\sin(\sin(\pi/3))) \cos(\pi/3)
\end{align}
```
To compute this product of 3 numbers, we have 2 choices.
The first possibility (which corresponds to forward diff) is to start by computing the product
```math
\begin{align}
\left. (\cos(\sin(x)))' \right|_{x = \pi/3}
& =
\left. (\cos(x))' \right|_{x = \sin(\pi/3)}
\left. (\sin(x))' \right|_{x = \pi/3}\\
& =
(-\sin(\sin(\pi/3))) \cos(\pi/3)
\end{align}
```
and then to multiply it by ``\left. (\tan(x))' \right|_{x = \cos(\sin(\pi/3))} = \frac{1}{\cos^2(\cos(\sin(\pi/3)))}``.
The second possibility (which corresponds to reverse diff) is to start by computing the product
```math
\begin{align}
\left. (\tan(\cos(x)))' \right|_{\textcolor{red}{x = \sin(\pi/3)}}
& =
\left. (\tan(x))' \right|_{\textcolor{red}{x = \cos(\sin(\pi/3))}}
\left. (\cos(x))' \right|_{\textcolor{red}{x = \sin(\pi/3)}}\\
& = \frac{1}{\cos^2(\cos(\sin(\pi/3)))} (-\sin(\sin(\pi/3)))
\end{align}
```
and then to multiply it by ``\cos(\pi/3)``.
You will notice that in the equation above, as highlighted in red, the points at which the derivatives must be evaluated depend on ``\sin(\pi/3)``.
The left-to-right multiplication used by reverse diff therefore cannot be carried out without first knowing values that are evaluated from right to left.
To apply reverse diff, we must therefore start with a *forward pass* from right to left that computes ``\sin(\pi/3)``, then ``\cos(\sin(\pi/3))``, then ``\tan(\cos(\sin(\pi/3)))``. We can then run the *backward pass*, which multiplies the derivatives from left to right. To be available to the backward pass, the values computed during the forward pass must be **stored**, which implies a **memory cost**.
In contrast, since forward diff computes the derivative in the same direction as the evaluation, derivatives and evaluations can be computed at the same time, so the evaluations need not be stored. This is indeed what we implemented with `Dual` earlier.
Given this extra memory cost of reverse diff compared to forward diff,
the latter seems preferable in practice.
We will now see that in the multivariate case, in some situations, this drawback is offset by a better time complexity that makes reverse diff indispensable!
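As a rough illustration of the two orderings, the sketch below evaluates the same derivative both ways in plain Julia. The function names `forward_chain` and `reverse_chain` are hypothetical and not part of this notebook; the code only assumes the chain ``\tan(\cos(\sin(x)))`` used above.

```julia
# Forward diff: the value `v` and the derivative `d` advance together,
# so no intermediate value needs to be kept.
function forward_chain(x)
    v, d = x, 1.0                # dx/dx = 1
    v, d = sin(v), cos(v) * d    # the right-hand side still sees the old `v`
    v, d = cos(v), -sin(v) * d
    v, d = tan(v), d / cos(v)^2
    return v, d
end

# Reverse diff: the forward pass stores `a` and `b`, then the backward
# pass multiplies the local derivatives starting from the output.
function reverse_chain(x)
    a = sin(x)             # stored for the backward pass
    b = cos(a)             # stored for the backward pass
    c = tan(b)
    adj = 1.0              # dc/dc
    adj *= 1 / cos(b)^2    # dc/db
    adj *= -sin(a)         # dc/da
    adj *= cos(x)          # dc/dx
    return c, adj
end
```

Both functions return the same pair (value, derivative); only `reverse_chain` has to keep `a` and `b` alive until the backward pass.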
#### Multivariate example
Let us now take a multivariate example: suppose we want to compute the gradient of the function ``f(g(h(x_1, x_2)))``, which composes 3 functions ``f``, ``g`` and ``h``.
The gradient is obtained via the chain rule as follows:
```math
\begin{align}
\frac{\partial}{\partial x_1} f(g(h(x_1, x_2)))
& = \frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_1}\\
\frac{\partial}{\partial x_2} f(g(h(x_1, x_2)))
& = \frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_2}\\
\nabla_{x_1, x_2} f(g(h(x_1, x_2)))
& = \begin{bmatrix}
\frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_1} &
\frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial g}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial g}{\partial h}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial h}{\partial x_1} &
\frac{\partial h}{\partial x_2}
\end{bmatrix}
\end{align}
```
We see that this is a product of 3 matrices. Forward diff evaluates this product from right to left:
```math
\begin{align}
\nabla_{x_1, x_2} f(g(h(x_1, x_2)))
& =
\begin{bmatrix}
\frac{\partial f}{\partial g}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial g}{\partial h}\frac{\partial h}{\partial x_1} &
\frac{\partial g}{\partial h}\frac{\partial h}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial g}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial g}{\partial x_1} &
\frac{\partial g}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial g}\frac{\partial g}{\partial x_1} &
\frac{\partial f}{\partial g}\frac{\partial g}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial x_1} &
\frac{\partial f}{\partial x_2}
\end{bmatrix}
\end{align}
```
The idea of reverse diff is to evaluate the product from left to right:
```math
\begin{align}
\nabla_{x_1, x_2} f(g(h(x_1, x_2)))
& =
\begin{bmatrix}
\frac{\partial f}{\partial g}\frac{\partial g}{\partial h}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial h}{\partial x_1} &
\frac{\partial h}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial h}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial h}{\partial x_1} &
\frac{\partial h}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial x_1} &
\frac{\partial f}{\partial x_2}
\end{bmatrix}
\end{align}
```
"""
# ╔═╡ 4e1ac5fc-c684-42e1-9c99-3120021eb19a
md"""
To compute ``\partial f / \partial x_1`` via forward diff, we therefore start from ``\partial x_1 / \partial x_1 = 1`` and ``\partial x_2 / \partial x_1 = 0``, then compute ``\partial h / \partial x_1``, ``\partial g / \partial x_1`` and finally ``\partial f / \partial x_1``.
Carrying out reverse diff is a little less intuitive. The idea is to start from the derivative of the result with respect to itself, ``\partial f / \partial f = 1``, and to compute ``\partial f / \partial g``, then ``\partial f / \partial h`` and finally ``\partial f / \partial x_1``. The advantage of reverse diff is that only the last step is specific to ``x_1``. Everything up to the computation of ``\partial f / \partial h`` can be reused to compute ``\partial f / \partial x_2``; all that remains is one multiplication! Reverse diff is therefore more efficient for computing the gradient of a function with a single output with respect to many parameters, as detailed in the discussion at the end of this notebook.
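To make the reuse of ``\partial f / \partial h`` concrete, here is a minimal sketch for one illustrative choice of functions: ``h(x_1, x_2) = x_1 x_2``, ``g = \sin`` and ``f = \cos``. The name `shared_backward` is hypothetical and this is not the notebook's implementation, only a hand-written instance of the idea.

```julia
function shared_backward(x1, x2)
    # Forward pass: evaluate and keep the intermediate values.
    h = x1 * x2
    g = sin(h)
    f = cos(g)
    # Backward pass, shared part: independent of the variable we differentiate against.
    df_dg = -sin(g)          # ∂f/∂g
    df_dh = df_dg * cos(h)   # ∂f/∂h = ∂f/∂g * ∂g/∂h
    # Only this last step is specific to each variable:
    # ∂h/∂x1 = x2 and ∂h/∂x2 = x1.
    return f, [df_dh * x2, df_dh * x1]
end
```

With more input variables, the shared part stays the same size and only the final vector grows, which is where the saving over forward diff comes from.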
"""
# ╔═╡ 56b32132-113f-459f-b1d9-abb8f439a40b
md"""
### Forward pass: building the expression graph
To implement reverse diff, we need to build the expression graph in order to keep in memory the values of the intermediate expressions, which are needed to compute the local derivatives ``\partial f / \partial g`` and ``\partial g / \partial h``. The following code defines a node of the expression graph. The `derivative` field holds the value of ``\partial f_{\text{final}} / \partial f_{\text{node}}``, where ``f_\text{final}`` is the last function of the composition and ``f_{\text{node}}`` is the function corresponding to the node.
"""
# ╔═╡ 4931adf1-8771-4708-833e-d05c05884969
begin
mutable struct Node{T}
mutable struct Node
op::Union{Nothing,Symbol}
args::Vector{Node{T}}
value::T
derivative::T
args::Vector{Node}
value::Float64
derivative::Float64
end
Node(op, args, value::T) where {T} = Node(op, args, value, zero(T))
Node(value::T) where {T} = Node(nothing, Node{T}[], value)
Node(op, args, value) = Node(op, args, value, NaN)
Node(value) = Node(nothing, Node[], value)
end
# ╔═╡ 0b07b9cf-83b4-46e9-9a75-cf2cadbbb011
md"""
The following operator overloading is sufficient to build the expression graph for this notebook; you will extend it during the exercise session.
"""
# ╔═╡ b814dc16-37de-45d1-9c7c-4eec45d3f956
begin
Base.zero(x::Node{T}) where {T} = Node(zero(T))
Base.zero(x::Node) = Node(0.0)
Base.:*(x::Node, y::Node) = Node(:*, [x, y], x.value * y.value)
Base.:+(x::Node, y::Node) = Node(:+, [x, y], x.value + y.value)
Base.:-(x::Node, y::Node) = Node(:-, [x, y], x.value - y.value)
Base.:/(x::Node, y::Number) = x * Node(inv(y))
Base.:^(x::Node, n::Integer) = Base.power_by_squaring(x, n)
Base.sin(x::Node) = Node(:sin, [x], sin(x.value))
Base.cos(x::Node) = Node(:cos, [x], cos(x.value))
end
# ╔═╡ 851e688f-2b30-44b7-9530-87990adee4b2
......@@ -273,20 +445,76 @@ end
# ╔═╡ b899a93f-9bec-48ce-b0ad-4e5157556a31
L1_loss(w, X, y) = sum(abs.(X * w - y)) / length(y)
# ╔═╡ cf81d0e0-4268-4ce5-82c5-8eca1e410233
md"Note: an improvement will be covered with topological sorting (see graph theory)"
# ╔═╡ 1572f901-c688-435e-81b9-d6e39bb82201
md"""
We create the leaves corresponding to the variables ``x_1`` and ``x_2``, with values ``1`` and ``2`` respectively. These are the values of ``x_1`` and ``x_2`` at which we want to differentiate the function. We chose 1 and 2 so that they are easy to recognize in the graph.
"""
# ╔═╡ 86872f35-d62d-40e5-8770-4585d3b0c0d7
function topo_sort!(visited, topo, f::Node)
if !(f in visited)
push!(visited, f)
for arg in f.args
topo_sort!(visited, topo, arg)
end
push!(topo, f)
# ╔═╡ b7faa3b7-e0b6-4f55-8763-035d8fc5ac93
x_nodes = Node.([1, 2])
# ╔═╡ 194d3a68-6bed-41d9-b3ea-8cfaf4787c54
expr = cos(sin(prod(x_nodes)))
# ╔═╡ 7285c7e8-bce0-42f0-a53b-562a8a6c5894
function _nodes!(nodes, x::Node)
if !(x in keys(nodes))
nodes[x] = length(nodes) + 1
end
for arg in x.args
_nodes!(nodes, arg)
end
end
# ╔═╡ b9f0e9b6-c111-4d69-990f-02c460c8706d
function _edges(g, labels, nodes::Dict{Node}, done, x::Node)
id = nodes[x]
if done[id]
return
end
done[id] = true
if isnothing(x.op)
labels[id] = string(x.value)
else
labels[id] = "[" * String(x.op) * "] " * string(x.value)
end
for arg in x.args
add_edge!(g, id, nodes[arg])
_edges(g, labels, nodes, done, arg)
end
end
# ╔═╡ 114c048f-f619-4e15-8e5a-de852b0a1861
function graph(x::Node)
nodes = Dict{Node,Int}()
_nodes!(nodes, x)
done = falses(length(nodes))
g = DiGraph(length(nodes))
labels = Vector{String}(undef, length(nodes))
_edges(g, labels, nodes, done, x)
return g, labels
end
# ╔═╡ 7578be43-8dbe-4041-adc7-275f06057bfe
md"""
To visualize it, we convert it to a graph using the data structure of Graphs.jl so that we can use `gplot`.
"""
# ╔═╡ 69298293-c9fc-432f-9c3c-5da7ce710334
expr_graph, labels = graph(expr)
# ╔═╡ 9b9e5fc3-a8d0-4c42-91ae-25dd01bc7d7e
gplot(expr_graph, nodelabel = labels)
# ╔═╡ bd705ddd-0d00-41a0-aa55-e82daad4133d
md"### Backward pass: computing the derivatives"
# ╔═╡ d8052188-f2fa-4ad8-935f-581eea164bda
md"""
The following function propagates the derivative ``\partial f_{\text{final}} / \partial f_{\text{node}}`` to the derivatives of the arguments of the function ``f_{\text{node}}``.
Since the arguments may also be used by other functions, we accumulate the derivative with `+=`.
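A small hand-written case shows why accumulation is needed. The sketch below uses a hypothetical `Leaf` type, independent of the notebook's `Node`, and differentiates ``f = x \cdot x``, where the same variable is used as both arguments of the product.

```julia
# Hypothetical minimal variable holder; not the notebook's `Node`.
mutable struct Leaf
    value::Float64
    derivative::Float64
end

x = Leaf(3.0, 0.0)
f_derivative = 1.0   # ∂f/∂f for f = x * x

# The product rule sends one contribution per argument. Because both
# arguments are the same variable, the two contributions must be summed.
x.derivative += f_derivative * x.value   # through the first argument
x.derivative += f_derivative * x.value   # through the second argument
```

After both updates `x.derivative` equals ``2x``, the expected derivative of ``x^2``; overwriting with `=` would have kept only one of the two contributions.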
"""
# ╔═╡ 1e08b49d-03fe-4fb3-a8ba-3a00e1374b32
function _backward!(f::Node)
if isnothing(f.op)
......@@ -301,11 +529,32 @@ function _backward!(f::Node)
elseif f.op == :* && length(f.args) == 2
f.args[1].derivative += f.derivative * f.args[2].value
f.args[2].derivative += f.derivative * f.args[1].value
elseif f.op == :sin
f.args[].derivative += f.derivative * cos(f.args[].value)
elseif f.op == :cos
f.args[].derivative -= f.derivative * sin(f.args[].value)
else
error("Operator `$(f.op)` not supported yet")
end
end
# ╔═╡ 44442c34-e088-493a-bfd6-9c095c499100
md"""
The function `_backward!` must only be called on a node for which `f.derivative` has already been computed. For that, `_backward!` must have been called on all the nodes representing functions that depend, directly or indirectly, on the result of the node.
To find the order in which to call `_backward!`, we therefore use a topological sort (we will come back to topological sorts in the graph part).
"""
# ╔═╡ 86872f35-d62d-40e5-8770-4585d3b0c0d7
function topo_sort!(visited, topo, f::Node)
if !(f in visited)
push!(visited, f)
for arg in f.args
topo_sort!(visited, topo, arg)
end
push!(topo, f)
end
end
# ╔═╡ 26c40cf4-9585-4762-abf4-ff77342a389f
function backward!(f::Node)
topo = typeof(f)[]
......@@ -321,6 +570,22 @@ function backward!(f::Node)
return f
end
# ╔═╡ 86fc7924-2002-4ac6-8e02-d9bf5edde9bf
backward!(expr)
# ╔═╡ d72d9c99-6280-49a2-9f7a-e9628f1069eb
md"We now have the derivative information on `x_nodes`:"
# ╔═╡ 0649437a-4198-4556-97dc-1b5cfbe45eed
x_nodes
# ╔═╡ 3928e9f7-9539-4d99-ac5b-6336eff8a523
md"""
### Comparison with forward diff on the moon example
Let us return to the example used to illustrate forward diff and try to compute the same derivative, this time using reverse diff.
"""
# ╔═╡ 5ce7fbad-af38-4ff6-adca-b1991f3be455
w_nodes = Node.(w)
......@@ -335,25 +600,34 @@ function reverse_diff(loss, w, X, y)
return [w.derivative for w in w_nodes]
end
# ╔═╡ a610dc3c-803a-4489-a84b-8bff415bc0a6
md"We execute it a second time to get rid of the compilation time:"
# ╔═╡ 0e99048d-5696-43ab-8896-301f37a20a5d
md"We observe that reverse diff is slower! Building the expression graph has a certain memory cost. For this reason, if we want to compute several consecutive derivatives for different values of ``x_1`` and ``x_2``, it is better to keep the graph and only change the values of the variables rather than rebuild the graph each time the values change. Alternatively, we can try to condense the graph by expressing the operations on large matrices or even tensors; this is the approach used by PyTorch and TensorFlow."
# ╔═╡ bd012d84-a79f-4043-961e-f7825b7e0d6c
md"`num_data` = $(@bind(num_data, Slider(1:100, default = 10, show_value = true)))"
md"`num_data` = $(@bind(num_data, Slider(1:100, default = 32, show_value = true)))"
# ╔═╡ 03f6d241-712d-4a45-b926-09be326c1c7d
md"`num_features` = $(@bind(num_features, Slider(1:100, default = 10, show_value = true)))"
md"`num_features` = $(@bind(num_features, Slider(1:100, default = 32, show_value = true)))"
# ╔═╡ e9507958-cefd-4208-896e-860d3e4e9d4b
md"`num_hidden` = $(@bind(num_hidden, Slider(1:100, default = 10, show_value = true)))"
md"`num_hidden` = $(@bind(num_hidden, Slider(1:100, default = 32, show_value = true)))"
# ╔═╡ 8ce88d4d-59b0-49bb-8c0a-3a2961e5fd4a
let
# ╔═╡ 8acebedb-3a97-4efd-bd3c-c267ecd3945c
mse2(W1, W2, X, y) = sum((X * W1 * W2 - y).^2 / length(y))
# ╔═╡ 32fdcdd9-785d-427b-94c3-3c65bf72e673
function bench(num_data, num_features, num_hidden)
X = rand(num_data, num_features)
W1 = rand(num_features, num_hidden)
W2 = rand(num_hidden)
y = rand(num_data)
mse(W1, W2, X, y) = sum((X * W1 * W2 - y).^2 / length(y))
@time for i in axes(W1, 1)
for j in axes(W1, 2)
mse(
mse2(
Dual.(W1, onehot(i, axes(W1, 1)) * onehot(j, axes(W1, 2))'),
W2,
X,
......@@ -361,9 +635,99 @@ let
)
end
end
expr = @time mse(Node.(W1), Node.(W2), Node.(X), Node.(y))
expr = @time mse2(Node.(W1), Node.(W2), Node.(X), Node.(y))
@time backward!(expr)
end;
return
end
# ╔═╡ 8ce88d4d-59b0-49bb-8c0a-3a2961e5fd4a
bench(num_data, num_features, num_hidden)
# ╔═╡ 613269ef-16ba-44ef-ad8e-997cc9aec1fb
md"""
### How to choose between forward and reverse diff?
Suppose that we need to differentiate a composition of functions:
``(f_n \circ f_{n-1} \circ \cdots \circ f_2 \circ f_1)(w)``.
For each function, we can compute a Jacobian given the value of its input.
So, during a forward pass, we can compute all Jacobians. We then just need to take the product of these Jacobians:
```math
J_n J_{n-1} \cdots J_2 J_1
```
While the product of matrices is associative, its computational complexity depends on the order of the multiplications!
Let ``d_i \times d_{i - 1}`` be the dimension of ``J_i``.
#### Forward diff: from right to left
If the product is computed from right to left:
```math
\begin{align}
J_{1,2} & = J_2 J_1 && \Omega(d_2d_1d_0)\\
J_{1,3} & = J_3 J_{1,2} && \Omega(d_3d_2d_0)\\
J_{1,4} & = J_4 J_{1,3} && \Omega(d_4d_3d_0)\\
\vdots & \quad \vdots\\
J_{1,n} & = J_n J_{1,(n-1)} && \Omega(d_nd_{n-1}d_0)
\end{align}
```
we have a complexity of
```math
\Omega(\sum_{i=2}^n d_id_{i-1}d_0).
```
#### Reverse diff: from left to right
Reverse differentiation corresponds to multiplying the adjoints from right to left or, equivalently, the original matrices from left to right.
This means computing the product in the following order:
```math
\begin{align}
J_{(n-1),n} & = J_n J_{n-1} && \Omega(d_nd_{n-1}d_{n-2})\\
J_{(n-2),n} & = J_{(n-1),n} J_{n-2} && \Omega(d_nd_{n-2}d_{n-3})\\
J_{(n-3),n} & = J_{(n-2),n} J_{n-3} && \Omega(d_nd_{n-3}d_{n-4})\\
\vdots & \quad \vdots\\
J_{1,n} & = J_{2,n} J_1 && \Omega(d_nd_1d_0)\\
\end{align}
```
We have a complexity of
```math
\Omega(\sum_{i=1}^{n-1} d_nd_id_{i-1}).
```
#### Mixed: from inward to outward
Suppose we multiply starting from some ``d_k`` where ``1 < k < n``.
We would then first compute the left side:
```math
\begin{align}
J_{k+1,k+2} & = J_{k+2} J_{k+1} && \Omega(d_{k+2}d_{k+1}d_{k})\\
J_{k+1,k+3} & = J_{k+3} J_{k+1,k+2} && \Omega(d_{k+3}d_{k+2}d_{k})\\
\vdots & \quad \vdots\\
J_{k+1,n} & = J_{n} J_{k+1,n-1} && \Omega(d_nd_{n-1}d_k)
\end{align}
```
then the right side:
```math
\begin{align}
J_{k-1,k} & = J_k J_{k-1} && \Omega(d_kd_{k-1}d_{k-2})\\
J_{k-2,k} & = J_{k-1,k} J_{k-2} && \Omega(d_kd_{k-2}d_{k-3})\\
\vdots & \quad \vdots\\
J_{1,k} & = J_{2,k} J_1 && \Omega(d_kd_1d_0)\\
\end{align}
```
and then combine both sides:
```math
J_{1,n} = J_{k+1,n} J_{1,k} \qquad \Omega(d_nd_kd_0)
```
we have a complexity of
```math
\Omega(d_nd_kd_0 + \sum_{i=1}^{k-1} d_kd_id_{i-1} + \sum_{i=k+2}^{n} d_id_{i-1}d_k).
```
#### Comparison
We see that we should find the minimum ``d_k`` and start from there. If the minimum is attained at ``k = n``, this corresponds to multiplying from left to right, which is reverse differentiation. If the minimum is attained at ``k = 0``, we should multiply from right to left, which is forward differentiation. Otherwise, we should start from the middle, which means mixing forward and reverse diff.
What about neural networks? In that case, ``d_0`` is the number of entries in ``W_1`` plus the number of entries in ``W_2``, while ``d_n`` is ``1`` since the loss is scalar. We should therefore clearly multiply from left to right, that is, use reverse diff.
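The two cost formulas can be compared numerically. The sketch below is only an illustration: the helper names `forward_cost` and `reverse_cost` and the example dimensions are assumptions, and the vector `d` stores ``d_0, d_1, \ldots, d_n`` with Julia's 1-based indexing, so that `d[i+1]` is ``d_i``. It assumes ``n \ge 2`` so that both sums are non-empty.

```julia
# Right to left (forward diff): sum over i = 2..n of d_i * d_{i-1} * d_0.
forward_cost(d) = sum(d[i+1] * d[i] * d[1] for i in 2:length(d)-1)

# Left to right (reverse diff): sum over i = 1..n-1 of d_n * d_i * d_{i-1}.
reverse_cost(d) = sum(d[end] * d[i+1] * d[i] for i in 1:length(d)-2)

# Illustrative dimensions: 1000 inputs, two layers of width 100, scalar output.
dims = [1000, 100, 100, 1]
forward_cost(dims)   # 10_100_000
reverse_cost(dims)   # 110_000
```

With a scalar output and many inputs, the left-to-right order is cheaper by roughly two orders of magnitude in this example, in line with the conclusion for neural networks.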
"""
# ╔═╡ fdd28672-5902-474a-8c87-3f6f38bcf54f
md"""
......@@ -530,6 +894,9 @@ plot_w(w_trained_L1)
# ╔═╡ 4444622b-ccfe-4867-b659-489573099f1e
@time reverse_diff(mse, w, X, y)
# ╔═╡ c1942ee2-c5af-4b2b-986d-6ad563ef27bb
@time reverse_diff(mse, w, X, y)
# ╔═╡ 0af6bce6-bc3b-438e-a57b-0c0c6586c0c5
W1 = rand(size(X, 2), num_hidden)
......@@ -567,6 +934,8 @@ cross_entropies = -log.(sum(Y_est .* Y, dims=2))
PLUTO_PROJECT_TOML_CONTENTS = """
[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
......@@ -575,10 +944,12 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
[compat]
Colors = "~0.12.11"
GraphPlot = "~0.6.0"
Graphs = "~1.9.0"
MLJBase = "~1.7.0"
OneHotArrays = "~0.2.5"
Plots = "~1.40.8"
PlutoUI = "~0.7.59"
PlutoUI = "~0.7.60"
Tables = "~1.12.0"
"""
......@@ -586,9 +957,9 @@ Tables = "~1.12.0"
PLUTO_MANIFEST_TOML_CONTENTS = """
# This file is machine-generated - editing it directly is not advised
julia_version = "1.11.0"
julia_version = "1.11.1"
manifest_format = "2.0"
project_hash = "7bb0257e84f3d9264b2fb594690718f02c77971a"
project_hash = "5c3de1419124fc7d203b55effaa59184d76d6b7e"
[[deps.AbstractPlutoDingetjes]]
deps = ["Pkg"]
......@@ -646,6 +1017,12 @@ version = "2.3.0"
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.2"
[[deps.ArnoldiMethod]]
deps = ["LinearAlgebra", "Random", "StaticArrays"]
git-tree-sha1 = "62e51b39331de8911e4a7ff6f5aaf38a5f4cc0ae"
uuid = "ec485272-7323-5ecc-a04f-4719b315124d"
version = "0.2.0"
[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
version = "1.11.0"
......@@ -694,9 +1071,9 @@ version = "0.1.9"
[[deps.Bzip2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd"
git-tree-sha1 = "8873e196c2eb87962a2048b3b8e08946535864a1"
uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0"
version = "1.0.8+1"
version = "1.0.8+2"
[[deps.CEnum]]
git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc"
......@@ -798,6 +1175,12 @@ deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.1+0"
[[deps.Compose]]
deps = ["Base64", "Colors", "DataStructures", "Dates", "IterTools", "JSON", "LinearAlgebra", "Measures", "Printf", "Random", "Requires", "Statistics", "UUIDs"]
git-tree-sha1 = "bf6570a34c850f99407b494757f5d7ad233a7257"
uuid = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
version = "0.9.5"
[[deps.CompositionsBase]]
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b"
......@@ -1025,15 +1408,15 @@ version = "0.1.6"
[[deps.GR]]
deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"]
git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5"
git-tree-sha1 = "ee28ddcd5517d54e417182fec3886e7412d3926f"
uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
version = "0.73.7"
version = "0.73.8"
[[deps.GR_jll]]
deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"]
git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d"
git-tree-sha1 = "f31929b9e67066bee48eec8b03c0df47d31a74b3"
uuid = "d2c73de3-f751-5644-a686-071e5b155ba9"
version = "0.73.7+0"
version = "0.73.8+0"
[[deps.Gettext_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"]
......@@ -1047,12 +1430,24 @@ git-tree-sha1 = "674ff0db93fffcd11a3573986e550d66cd4fd71f"
uuid = "7746bdde-850d-59dc-9ae8-88ece973131d"
version = "2.80.5+0"
[[deps.GraphPlot]]
deps = ["ArnoldiMethod", "Colors", "Compose", "DelimitedFiles", "Graphs", "LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "f76a7a0f10af6ce7f227b7a921bfe351f628ed45"
uuid = "a2cc645c-3eea-5389-862e-a155d0052231"
version = "0.6.0"
[[deps.Graphite2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "344bf40dcab1073aca04aa0df4fb092f920e4011"
uuid = "3b182d85-2403-5c21-9c21-1e1f0cc25472"
version = "1.3.14+0"
[[deps.Graphs]]
deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"]
git-tree-sha1 = "899050ace26649433ef1af25bc17a815b3db52b7"
uuid = "86223c79-3864-5bf0-83f7-82e725a168b6"
version = "1.9.0"
[[deps.Grisu]]
git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2"
uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe"
......@@ -1094,6 +1489,11 @@ git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770"
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
version = "0.2.5"
[[deps.Inflate]]
git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d"
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
version = "0.1.5"
[[deps.InitialValues]]
git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3"
uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c"
......@@ -1124,6 +1524,11 @@ git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.2.2"
[[deps.IterTools]]
git-tree-sha1 = "42d5f897009e7ff2cf88db414a389e5ed1bdd023"
uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
version = "1.10.0"
[[deps.IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
......@@ -1137,9 +1542,9 @@ version = "0.1.8"
[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40"
git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.6.0"
version = "1.6.1"
[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
......@@ -1182,10 +1587,10 @@ uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d"
version = "3.100.2+0"
[[deps.LERC_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434"
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "36bdbc52f13a7d1dcb0f3cd694e01677a515655b"
uuid = "88015f11-f218-50d7-93a8-a6af411a945d"
version = "3.0.0+1"
version = "4.0.0+0"
[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"]
......@@ -1218,9 +1623,9 @@ uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac"
version = "2.10.2+1"
[[deps.LaTeXStrings]]
git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec"
git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c"
uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
version = "1.3.1"
version = "1.4.0"
[[deps.Latexify]]
deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"]
......@@ -1316,9 +1721,9 @@ version = "2.40.1+0"
[[deps.Libtiff_jll]]
deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"]
git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a"
git-tree-sha1 = "b404131d06f7886402758c9ce2214b636eb4d54a"
uuid = "89763e89-9b03-5906-acba-b20f662cd828"
version = "4.5.1+1"
version = "4.7.0+0"
[[deps.Libuuid_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
......@@ -1583,10 +1988,10 @@ uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a"
version = "3.2.0"
[[deps.PlotUtils]]
deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"]
git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5"
deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "StableRNGs", "Statistics"]
git-tree-sha1 = "650a022b2ce86c7dcfbdecf00f78afeeb20e5655"
uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043"
version = "1.4.1"
version = "1.4.2"
[[deps.Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"]
......@@ -1610,9 +2015,9 @@ version = "1.40.8"
[[deps.PlutoUI]]
deps = ["AbstractPlutoDingetjes", "Base64", "ColorTypes", "Dates", "FixedPointNumbers", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "MIMEs", "Markdown", "Random", "Reexport", "URIs", "UUIDs"]
git-tree-sha1 = "ab55ee1510ad2af0ff674dbcced5e94921f867a9"
git-tree-sha1 = "eba4810d5e6a01f612b948c9fa94f905b49087b0"
uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
version = "0.7.59"
version = "0.7.60"
[[deps.PrecompileTools]]
deps = ["Preferences"]
......@@ -1771,6 +2176,11 @@ git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac"
uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46"
version = "1.1.1"
[[deps.SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
version = "1.11.0"
[[deps.ShowCases]]
git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5"
uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
......@@ -1824,6 +2234,12 @@ git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5"
uuid = "171d559e-b47b-412a-8079-5efa626c420e"
version = "0.1.15"
[[deps.StableRNGs]]
deps = ["Random"]
git-tree-sha1 = "83e6cce8324d49dfaf9ef059227f91ed4441a8e5"
uuid = "860ef19b-820b-49d6-a774-d7a799459cd3"
version = "1.0.2"
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50"
......@@ -2398,22 +2814,46 @@ version = "1.4.1+1"
# ╠═91e5840a-8a18-4d36-8fce-510af4c7dcf2
# ╠═4a807c44-f33e-4a58-99b3-261c392be891
# ╠═b5f3c2b1-c86a-48e4-973b-ee639d784936
# ╟─613269ef-16ba-44ef-ad8e-997cc9aec1fb
# ╟─db28bb45-3418-4080-a0fc-9136fc0196a5
# ╟─4e1ac5fc-c684-42e1-9c99-3120021eb19a
# ╟─56b32132-113f-459f-b1d9-abb8f439a40b
# ╠═4931adf1-8771-4708-833e-d05c05884969
# ╟─0b07b9cf-83b4-46e9-9a75-cf2cadbbb011
# ╠═b814dc16-37de-45d1-9c7c-4eec45d3f956
# ╟─cf81d0e0-4268-4ce5-82c5-8eca1e410233
# ╠═86872f35-d62d-40e5-8770-4585d3b0c0d7
# ╟─1572f901-c688-435e-81b9-d6e39bb82201
# ╠═b7faa3b7-e0b6-4f55-8763-035d8fc5ac93
# ╠═194d3a68-6bed-41d9-b3ea-8cfaf4787c54
# ╟─7285c7e8-bce0-42f0-a53b-562a8a6c5894
# ╟─b9f0e9b6-c111-4d69-990f-02c460c8706d
# ╟─114c048f-f619-4e15-8e5a-de852b0a1861
# ╟─7578be43-8dbe-4041-adc7-275f06057bfe
# ╠═69298293-c9fc-432f-9c3c-5da7ce710334
# ╠═9b9e5fc3-a8d0-4c42-91ae-25dd01bc7d7e
# ╟─bd705ddd-0d00-41a0-aa55-e82daad4133d
# ╟─d8052188-f2fa-4ad8-935f-581eea164bda
# ╠═1e08b49d-03fe-4fb3-a8ba-3a00e1374b32
# ╟─44442c34-e088-493a-bfd6-9c095c499100
# ╠═26c40cf4-9585-4762-abf4-ff77342a389f
# ╠═86872f35-d62d-40e5-8770-4585d3b0c0d7
# ╠═86fc7924-2002-4ac6-8e02-d9bf5edde9bf
# ╟─d72d9c99-6280-49a2-9f7a-e9628f1069eb
# ╠═0649437a-4198-4556-97dc-1b5cfbe45eed
# ╟─3928e9f7-9539-4d99-ac5b-6336eff8a523
# ╠═5ce7fbad-af38-4ff6-adca-b1991f3be455
# ╠═0dd8e1bf-f8c0-4183-a5eb-13eeb5316a7b
# ╠═7a320b75-c104-43d1-9129-f7f53910f5bc
# ╠═c1b73208-f917-4823-bf45-d896f4ee59e0
# ╠═4444622b-ccfe-4867-b659-489573099f1e
# ╟─a610dc3c-803a-4489-a84b-8bff415bc0a6
# ╠═c1942ee2-c5af-4b2b-986d-6ad563ef27bb
# ╟─0e99048d-5696-43ab-8896-301f37a20a5d
# ╟─bd012d84-a79f-4043-961e-f7825b7e0d6c
# ╟─03f6d241-712d-4a45-b926-09be326c1c7d
# ╟─e9507958-cefd-4208-896e-860d3e4e9d4b
# ╠═8acebedb-3a97-4efd-bd3c-c267ecd3945c
# ╠═32fdcdd9-785d-427b-94c3-3c65bf72e673
# ╠═8ce88d4d-59b0-49bb-8c0a-3a2961e5fd4a
# ╟─613269ef-16ba-44ef-ad8e-997cc9aec1fb
# ╟─fdd28672-5902-474a-8c87-3f6f38bcf54f
# ╟─54697e82-ee8c-4b65-a633-b29a47fac722
# ╠═0af6bce6-bc3b-438e-a57b-0c0c6586c0c5
......@@ -2435,6 +2875,7 @@ version = "1.4.1+1"
# ╟─c1da4130-5936-499f-bb9b-574e01136eca
# ╟─b16f6225-1949-4b6d-a4b0-c5c230eb4c7f
# ╠═dad5cba4-9bc6-47c3-a932-f2cc496b0f40
# ╠═9f027cde-dba0-4da5-8c42-5fa79b3929d6
# ╠═f1ba3d3c-d0a5-4290-ab73-9ce34bd5e5f6
# ╟─00000000-0000-0000-0000-000000000001
# ╟─00000000-0000-0000-0000-000000000002