From c020df02847e771397d70175fde4124f44e78c2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= <benoit.legat@gmail.com> Date: Tue, 5 Nov 2024 14:17:06 +0100 Subject: [PATCH] update 4_autodiff --- 4_autodiff.jl | 57 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/4_autodiff.jl b/4_autodiff.jl index 90c540e..caee906 100644 --- a/4_autodiff.jl +++ b/4_autodiff.jl @@ -15,7 +15,7 @@ macro bind(def, element) end # â•”â•â•¡ 9f027cde-dba0-4da5-8c42-5fa79b3929d6 -using Graphs, GraphPlot +using Graphs, GraphPlot, Printf # â•”â•â•¡ f1ba3d3c-d0a5-4290-ab73-9ce34bd5e5f6 using Plots, OneHotArrays, PlutoUI @@ -270,7 +270,10 @@ En revanche, comme forward diff calcule la dérivée dans le même sens que l'é Au vu de ce coût mémoire supplémentaire de reverse diff par rapport à forward diff, ce dernier paraît préférable en pratique. On va voir maintenant que dans le cas multivarié, dans certains cas, ce désavantage est contrebalancé par une meilleure complexité temporelle qui rend reverse diff indispensable! +""" +# â•”â•â•¡ 494fc7d7-c622-41d1-91d8-3dc1fbd2f244 +md""" #### Exemple multivarié Prenons maintenant un example multivarié, supposons qu'on veuille calculer le gradient de la fonction ``f(g(h(x_1, x_2)))`` qui compose 3 fonctions ``f``, ``g`` et ``h``. @@ -278,13 +281,13 @@ Le gradient est obtenu via la chain rule comme suit: ```math \begin{align} \frac{\partial}{\partial x_1} f(g(h(x_1, x_2))) -& = \frac{\partial f}{\partial g} \frac{\partial f}{\partial h} \frac{\partial h}{\partial x_1}\\ +& = \frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_1}\\ \frac{\partial}{\partial x_2} f(g(h(x_1, x_2))) -& = \frac{\partial f}{\partial g} \frac{\partial f}{\partial h} \frac{\partial h}{\partial x_2}\\ +& = \frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_2}\\ \nabla_{x_1, x_2} f(g(h(x_1, x_2))) & = \begin{bmatrix} -\frac{\partial f}{\partial g} \frac{\partial f}{\partial h} \frac{\partial h}{\partial x_1} & -\frac{\partial f}{\partial g} \frac{\partial f}{\partial h} \frac{\partial h}{\partial x_2} +\frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_1} & +\frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_2} \end{bmatrix}\\ & = \begin{bmatrix} @@ -473,10 +476,9 @@ function _edges(g, labels, nodes::Dict{Node}, done, x::Node) return end done[id] = true - if isnothing(x.op) - labels[id] = string(x.value) - else - labels[id] = "[" * String(x.op) * "] " * string(x.value) + labels[id] = @sprintf "%.2f" x.value + if !isnothing(x.op) + labels[id] = "[" * String(x.op) * "] " * labels[id] end for arg in x.args add_edge!(g, id, nodes[arg]) @@ -506,6 +508,34 @@ expr_graph, labels = graph(expr) # â•”â•â•¡ 9b9e5fc3-a8d0-4c42-91ae-25dd01bc7d7e gplot(expr_graph, nodelabel = labels) +# â•”â•â•¡ 54af5dab-e669-45eb-b6b5-44c46f7258b4 +md""" +#### Combinaison des dérivées + +Que faire si plusieurs expressions dépendent d'une même variable. +Considérons l'exemple ``f(x) = \sin(x)\cos(x)`` qui correspond à ``f(g,h) = gh``, ``g(x) = \sin(x)`` et ``h(x) = \cos(x)``. +La chain rule donne +```math +\begin{align} + f'(x) & = \frac{\partial f}{\partial g}g'(x) + \frac{\partial f}{\partial h}h'(x) +\end{align} +``` +Une fois la valeur ``\partial f / \partial g`` calculée, on peut la multiplier par ``g'(x)`` pour avoir la première partie partie de ``f'(x)``. Idem pour ``h``. Ces deux contributions seront calculée séparément lors de la backward pass sur le noeud ``g`` et ``h``. On voit par la formule de la chain rule que ces deux contributions doivent être sommées. +Lors de la backward pass, on initialise donc toutes les dérivées à 0. Pour chaque contribution, on ajoute la dérivée avec `+=`. On s'assure ensuite qu'on ne procède pas à la backward pass sur un noeud avant qu'il ait fini d'accumuler les contributions de toutes les expressions qui en dépendent via un *tri topologique*. +""" + +# â•”â•â•¡ 842df050-e0be-441d-b84e-7f0575eac227 +x_node = Node(1) + +# â•”â•â•¡ 36fcfd5e-c521-42fc-96cd-93787e657627 +sin_cos = sin(x_node) * cos(x_node) + +# â•”â•â•¡ f1da5c61-9862-44f6-84e4-96217736e1cb +sin_cos_graph, sin_cos_labels = graph(sin_cos) + +# â•”â•â•¡ 363af09c-3cc9-440d-9b70-1da8f6a70913 +gplot(sin_cos_graph, nodelabel = sin_cos_labels) + # â•”â•â•¡ bd705ddd-0d00-41a0-aa55-e82daad4133d md"### Backward pass : Calcul des dérivées" @@ -940,6 +970,7 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] @@ -959,7 +990,7 @@ PLUTO_MANIFEST_TOML_CONTENTS = """ julia_version = "1.11.1" manifest_format = "2.0" -project_hash = "5c3de1419124fc7d203b55effaa59184d76d6b7e" +project_hash = "647c126887a2ffbf11df74e3cf03bb0463ccc843" [[deps.AbstractPlutoDingetjes]] deps = ["Pkg"] @@ -2815,6 +2846,7 @@ version = "1.4.1+1" # â• â•4a807c44-f33e-4a58-99b3-261c392be891 # â• â•b5f3c2b1-c86a-48e4-973b-ee639d784936 # ╟─db28bb45-3418-4080-a0fc-9136fc0196a5 +# ╟─494fc7d7-c622-41d1-91d8-3dc1fbd2f244 # ╟─4e1ac5fc-c684-42e1-9c99-3120021eb19a # ╟─56b32132-113f-459f-b1d9-abb8f439a40b # â• â•4931adf1-8771-4708-833e-d05c05884969 @@ -2829,6 +2861,11 @@ version = "1.4.1+1" # ╟─7578be43-8dbe-4041-adc7-275f06057bfe # â• â•69298293-c9fc-432f-9c3c-5da7ce710334 # â• â•9b9e5fc3-a8d0-4c42-91ae-25dd01bc7d7e +# ╟─54af5dab-e669-45eb-b6b5-44c46f7258b4 +# â• â•842df050-e0be-441d-b84e-7f0575eac227 +# â• â•36fcfd5e-c521-42fc-96cd-93787e657627 +# â• â•f1da5c61-9862-44f6-84e4-96217736e1cb +# â• â•363af09c-3cc9-440d-9b70-1da8f6a70913 # ╟─bd705ddd-0d00-41a0-aa55-e82daad4133d # ╟─d8052188-f2fa-4ad8-935f-581eea164bda # â• â•1e08b49d-03fe-4fb3-a8ba-3a00e1374b32 -- GitLab