From c020df02847e771397d70175fde4124f44e78c2a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Beno=C3=AEt=20Legat?= <benoit.legat@gmail.com>
Date: Tue, 5 Nov 2024 14:17:06 +0100
Subject: [PATCH] update 4_autodiff

---
 4_autodiff.jl | 57 ++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 47 insertions(+), 10 deletions(-)

diff --git a/4_autodiff.jl b/4_autodiff.jl
index 90c540e..caee906 100644
--- a/4_autodiff.jl
+++ b/4_autodiff.jl
@@ -15,7 +15,7 @@ macro bind(def, element)
 end
 
 # ╔═╡ 9f027cde-dba0-4da5-8c42-5fa79b3929d6
-using Graphs, GraphPlot
+using Graphs, GraphPlot, Printf
 
 # ╔═╡ f1ba3d3c-d0a5-4290-ab73-9ce34bd5e5f6
 using Plots, OneHotArrays, PlutoUI
@@ -270,7 +270,10 @@ En revanche, comme forward diff calcule la dérivée dans le même sens que l'é
 Au vu de ce coût mémoire supplémentaire de reverse diff par rapport à forward diff,
 ce dernier paraît préférable en pratique.
 On va voir maintenant que dans le cas multivarié, dans certains cas, ce désavantage est contrebalancé par une meilleure complexité temporelle qui rend reverse diff indispensable!
+"""
 
+# ╔═╡ 494fc7d7-c622-41d1-91d8-3dc1fbd2f244
+md"""
 #### Exemple multivarié
 
 Prenons maintenant un example multivarié, supposons qu'on veuille calculer le gradient de la fonction ``f(g(h(x_1, x_2)))`` qui compose 3 fonctions ``f``, ``g`` et ``h``.
@@ -278,13 +281,13 @@ Le gradient est obtenu via la chain rule comme suit:
 ```math
 \begin{align}
 \frac{\partial}{\partial x_1} f(g(h(x_1, x_2)))
-& = \frac{\partial f}{\partial g} \frac{\partial f}{\partial h} \frac{\partial h}{\partial x_1}\\
+& = \frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_1}\\
 \frac{\partial}{\partial x_2} f(g(h(x_1, x_2)))
-& = \frac{\partial f}{\partial g} \frac{\partial f}{\partial h} \frac{\partial h}{\partial x_2}\\
+& = \frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_2}\\
 \nabla_{x_1, x_2} f(g(h(x_1, x_2)))
 & = \begin{bmatrix}
-\frac{\partial f}{\partial g} \frac{\partial f}{\partial h} \frac{\partial h}{\partial x_1} &
-\frac{\partial f}{\partial g} \frac{\partial f}{\partial h} \frac{\partial h}{\partial x_2}
+\frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_1} &
+\frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_2}
 \end{bmatrix}\\
 & =
 \begin{bmatrix}
@@ -473,10 +476,9 @@ function _edges(g, labels, nodes::Dict{Node}, done, x::Node)
 		return
 	end
 	done[id] = true
-	if isnothing(x.op)
-		labels[id] = string(x.value)
-	else
-		labels[id] = "[" * String(x.op) * "] " * string(x.value)
+	labels[id] = @sprintf "%.2f" x.value
+	if !isnothing(x.op)
+		labels[id] = "[" * String(x.op) * "] " * labels[id]
 	end
 	for arg in x.args
 		add_edge!(g, id, nodes[arg])
@@ -506,6 +508,34 @@ expr_graph, labels = graph(expr)
 # ╔═╡ 9b9e5fc3-a8d0-4c42-91ae-25dd01bc7d7e
 gplot(expr_graph, nodelabel = labels)
 
+# ╔═╡ 54af5dab-e669-45eb-b6b5-44c46f7258b4
+md"""
+#### Combinaison des dérivées
+
+Que faire si plusieurs expressions dépendent d'une même variable.
+Considérons l'exemple ``f(x) = \sin(x)\cos(x)`` qui correspond à ``f(g,h) = gh``, ``g(x) = \sin(x)`` et ``h(x) = \cos(x)``.
+La chain rule donne
+```math
+\begin{align}
+  f'(x) & = \frac{\partial f}{\partial g}g'(x) + \frac{\partial f}{\partial h}h'(x)
+\end{align}
+```
+Une fois la valeur ``\partial f / \partial g`` calculée, on peut la multiplier par ``g'(x)`` pour avoir la première partie partie de ``f'(x)``. Idem pour ``h``. Ces deux contributions seront calculée séparément lors de la backward pass sur le noeud ``g`` et ``h``. On voit par la formule de la chain rule que ces deux contributions doivent être sommées.
+Lors de la backward pass, on initialise donc toutes les dérivées à 0. Pour chaque contribution, on ajoute la dérivée avec `+=`. On s'assure ensuite qu'on ne procède pas à la backward pass sur un noeud avant qu'il ait fini d'accumuler les contributions de toutes les expressions qui en dépendent via un *tri topologique*.
+"""
+
+# ╔═╡ 842df050-e0be-441d-b84e-7f0575eac227
+x_node = Node(1)
+
+# ╔═╡ 36fcfd5e-c521-42fc-96cd-93787e657627
+sin_cos = sin(x_node) * cos(x_node)
+
+# ╔═╡ f1da5c61-9862-44f6-84e4-96217736e1cb
+sin_cos_graph, sin_cos_labels = graph(sin_cos)
+
+# ╔═╡ 363af09c-3cc9-440d-9b70-1da8f6a70913
+gplot(sin_cos_graph, nodelabel = sin_cos_labels)
+
 # ╔═╡ bd705ddd-0d00-41a0-aa55-e82daad4133d
 md"### Backward pass : Calcul des dérivées"
 
@@ -940,6 +970,7 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
@@ -959,7 +990,7 @@ PLUTO_MANIFEST_TOML_CONTENTS = """
 
 julia_version = "1.11.1"
 manifest_format = "2.0"
-project_hash = "5c3de1419124fc7d203b55effaa59184d76d6b7e"
+project_hash = "647c126887a2ffbf11df74e3cf03bb0463ccc843"
 
 [[deps.AbstractPlutoDingetjes]]
 deps = ["Pkg"]
@@ -2815,6 +2846,7 @@ version = "1.4.1+1"
 # ╠═4a807c44-f33e-4a58-99b3-261c392be891
 # ╠═b5f3c2b1-c86a-48e4-973b-ee639d784936
 # ╟─db28bb45-3418-4080-a0fc-9136fc0196a5
+# ╟─494fc7d7-c622-41d1-91d8-3dc1fbd2f244
 # ╟─4e1ac5fc-c684-42e1-9c99-3120021eb19a
 # ╟─56b32132-113f-459f-b1d9-abb8f439a40b
 # ╠═4931adf1-8771-4708-833e-d05c05884969
@@ -2829,6 +2861,11 @@ version = "1.4.1+1"
 # ╟─7578be43-8dbe-4041-adc7-275f06057bfe
 # ╠═69298293-c9fc-432f-9c3c-5da7ce710334
 # ╠═9b9e5fc3-a8d0-4c42-91ae-25dd01bc7d7e
+# ╟─54af5dab-e669-45eb-b6b5-44c46f7258b4
+# ╠═842df050-e0be-441d-b84e-7f0575eac227
+# ╠═36fcfd5e-c521-42fc-96cd-93787e657627
+# ╠═f1da5c61-9862-44f6-84e4-96217736e1cb
+# ╠═363af09c-3cc9-440d-9b70-1da8f6a70913
 # ╟─bd705ddd-0d00-41a0-aa55-e82daad4133d
 # ╟─d8052188-f2fa-4ad8-935f-581eea164bda
 # ╠═1e08b49d-03fe-4fb3-a8ba-3a00e1374b32
-- 
GitLab