Commit 62285724 authored by Benoît Legat

Add lab for next week

parent 53c160fa
Pipeline #54204 failed
[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
......
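The new dependencies can be installed by instantiating the lab's environment, e.g. (a minimal sketch, assuming the Project.toml above sits next to the script):

using Pkg
Pkg.activate(@__DIR__)  # activate the environment containing the Project.toml above
Pkg.instantiate()       # install Colors, MLJBase, OneHotArrays, Plots, ...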
@@ -20,6 +20,7 @@ function random_moon(num_data; noise = 0.1)
end
X, y = random_moon(num_data)
using Colors
function plot_moon(model, W, X, y)
    col = [Colors.JULIA_LOGO_COLORS.red, Colors.JULIA_LOGO_COLORS.blue]
    scatter(X[:, 1], X[:, 2], markerstrokewidth = 0, color = col[round.(Int, (3 .+ y) / 2)], label = "")
......
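(The color indexing assumes labels y ∈ {-1, 1}: round.(Int, (3 .+ y) / 2) maps -1 to (3 - 1)/2 = 1 and 1 to (3 + 1)/2 = 2, selecting red and blue respectively.)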
include("lab.jl")
using Test, LinearAlgebra
X, y = random_moon(num_data)
W1 = rand(size(X, 2), num_hidden)
W2 = rand(num_hidden)
W = (W1, W2)
include(joinpath(@__DIR__, "reverse.jl"))
∇f = @time forward_diff(mse, identity_activation, W, X, y)
∇r = @time reverse_diff(mse, identity_activation, W, X, y)
# The difference should be on the order of `1e-15` unless something is wrong:
norm.(∇f .- ∇r)
@test all(∇f .≈ ∇r)
∇f = @time forward_diff(mse, tanh_activation, W, X, y)
∇r = @time reverse_diff(mse, tanh_activation, W, X, y)
# The difference should be on the order of `1e-15` unless something is wrong:
norm.(∇f .- ∇r)
@test all(∇f .≈ ∇r)
∇f = @time forward_diff(mse, relu_activation, W, X, y)
∇r = @time reverse_diff(mse, relu_activation, W, X, y)
norm.(∇f .- ∇r)
@test all(∇f .≈ ∇r)
Y_encoded = one_hot_encode(y)
W = (rand(size(X, 2), num_hidden), rand(num_hidden, size(Y_encoded, 2)))
∇f = @time forward_diff(cross_entropy, relu_softmax, W, X, Y_encoded)
∇r = @time reverse_diff(cross_entropy, relu_softmax, W, X, Y_encoded)
norm.(∇f .- ∇r)
@test all(∇f .≈ ∇r)
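With forward and reverse differentiation agreeing, the gradients can drive training. A minimal gradient-descent sketch (not part of the lab files; the learning rate η and the step count are assumptions):

η = 0.1
for step in 1:100
    # one descent step; tuples broadcast elementwise, so both weight arrays are updated
    global W = W .- η .* reverse_diff(cross_entropy, relu_softmax, W, X, Y_encoded)
end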
mutable struct Node
    op::Union{Nothing,Symbol}  # operation that produced this node (`nothing` for a leaf)
    args::Vector{Node}         # child nodes, i.e., the operands of `op`
    value::Float64             # value computed during the forward pass
    derivative::Float64        # adjoint accumulated during the backward pass
end
Node(op, args, value) = Node(op, args, value, 0.0)
Node(value) = Node(nothing, Node[], value)
Base.zero(::Node) = Node(0)
Base.:*(x::Node, y::Node) = Node(:*, [x, y], x.value * y.value)
Base.:+(x::Node, y::Node) = Node(:+, [x, y], x.value + y.value)
Base.:-(x::Node, y::Node) = Node(:-, [x, y], x.value - y.value)
Base.:/(x::Node, y::Number) = x * Node(inv(y))
Base.:^(x::Node, n::Integer) = Base.power_by_squaring(x, n)
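Note that only +, - and * need explicit backward rules in _backward! below: / is rewritten as multiplication by Node(inv(y)), and ^ is expanded by power_by_squaring into repeated multiplications, so both reduce to the product rule.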
# Depth-first postorder traversal: each node is pushed to `topo` only after
# all of its arguments, yielding a topological order of the expression graph.
function topo_sort!(visited, topo, f::Node)
    if !(f in visited)
        push!(visited, f)
        for arg in f.args
            topo_sort!(visited, topo, arg)
        end
        push!(topo, f)
    end
end
# Propagate the derivative of `f` to its arguments (one step of the chain rule).
# Derivatives are accumulated with `+=` so that a node used several times
# (fan-out) receives the sum of the contributions of all its uses.
function _backward!(f::Node)
    if isnothing(f.op)
        return  # leaf node: nothing to propagate
    elseif f.op == :+
        for arg in f.args
            arg.derivative += f.derivative
        end
    elseif f.op == :- && length(f.args) == 2
        f.args[1].derivative += f.derivative
        f.args[2].derivative -= f.derivative
    elseif f.op == :* && length(f.args) == 2
        f.args[1].derivative += f.derivative * f.args[2].value
        f.args[2].derivative += f.derivative * f.args[1].value
    else
        error("Operator `$(f.op)` not supported yet")
    end
end
function backward!(f::Node)
    topo = typeof(f)[]
    topo_sort!(Set{typeof(f)}(), topo, f)
    reverse!(topo)            # visit `f` first, leaves last
    for node in topo
        node.derivative = 0   # reset adjoints from any previous call
    end
    f.derivative = 1          # seed: ∂f/∂f = 1
    for node in topo
        _backward!(node)
    end
    return f
end
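As a quick sanity check of the graph machinery (a sketch, not part of the lab files):

x = Node(2.0)
y = Node(3.0)
f = x * y + x   # records the graph while computing f.value == 8.0
backward!(f)
x.derivative    # ∂f/∂x = y.value + 1 == 4.0
y.derivative    # ∂f/∂y = x.value == 2.0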
function reverse_diff(loss, model, W::Tuple, X, y)
    # Wrap every scalar entry of the weights, inputs and targets in a `Node`
    # so that evaluating the loss records the expression graph.
    W_nodes = broadcast.(Node, W)
    expr = loss(model(W_nodes, Node.(X)), Node.(y))
    backward!(expr)
    return broadcast.(w -> w.derivative, W_nodes)
end
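reverse_diff implements reverse-mode AD by operator overloading at the scalar level: broadcast.(Node, W) wraps every weight entry, so evaluating the loss records one graph node per scalar operation. This keeps the code short and generic, though the per-scalar graph carries noticeable overhead compared to a vectorized implementation, which is worth keeping in mind when reading the @time outputs above.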