diff --git a/LabAutoDiff/Project.toml b/LabAutoDiff/Project.toml
index e608bc0bd229a5ee3aae6a011c18206c3475e006..e61424faa7c9901d9c54c4e03cd49cfb26b2df1c 100644
--- a/LabAutoDiff/Project.toml
+++ b/LabAutoDiff/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
diff --git a/LabAutoDiff/lab.jl b/LabAutoDiff/lab.jl
index 9b6d4e7a2654574bbb0c02eb05ece82ead9c8ea3..38c7d6d62b655ffe11b79c53bbed1b395b422c3d 100644
--- a/LabAutoDiff/lab.jl
+++ b/LabAutoDiff/lab.jl
@@ -20,6 +20,7 @@ function random_moon(num_data; noise = 0.1)
 end
 
 X, y = random_moon(num_data)
+using Colors
 function plot_moon(model, W, X, y)
     col = [Colors.JULIA_LOGO_COLORS.red, Colors.JULIA_LOGO_COLORS.blue]
     scatter(X[:, 1], X[:, 2], markerstrokewidth=0, color = col[round.(Int, (3 .+ y) / 2)], label = "")
diff --git a/LabAutoDiff/lab_reverse.jl b/LabAutoDiff/lab_reverse.jl
new file mode 100644
index 0000000000000000000000000000000000000000..1c70e7610dc88876f8befdc24cec9207acfeb9cb
--- /dev/null
+++ b/LabAutoDiff/lab_reverse.jl
@@ -0,0 +1,39 @@
+include("lab.jl")
+using Test, LinearAlgebra
+
+X, y = random_moon(num_data)
+W1 = rand(size(X, 2), num_hidden)
+W2 = rand(num_hidden)
+W = (W1, W2)
+
+include(joinpath(@__DIR__, "reverse.jl"))
+
+∇f = @time forward_diff(mse, identity_activation, W, X, y)
+∇r = @time reverse_diff(mse, identity_activation, W, X, y)
+
+# We should get a difference at the order of `1e-15` unless we got it wrong:
+norm.(∇f .- ∇r)
+@test all(∇f .≈ ∇r)
+
+
+∇f = @time forward_diff(mse, tanh_activation, W, X, y)
+∇r = @time reverse_diff(mse, tanh_activation, W, X, y)
+
+# We should get a difference at the order of `1e-15` unless we got it wrong:
+norm.(∇f .- ∇r)
+@test all(∇f .≈ ∇r)
+
+∇f = @time forward_diff(mse, relu_activation, W, X, y)
+∇r = @time reverse_diff(mse, relu_activation, W, X, y)
+
+norm.(∇f .- ∇r)
+@test all(∇f .≈ ∇r)
+
+Y_encoded = one_hot_encode(y)
+W = (rand(size(X, 2), num_hidden), rand(num_hidden, size(Y_encoded, 2)))
+
+∇f = @time forward_diff(cross_entropy, relu_softmax, W, X, Y_encoded)
+∇r = @time reverse_diff(cross_entropy, relu_softmax, W, X, Y_encoded)
+
+norm.(∇f .- ∇r)
+@test all(∇f .≈ ∇r)
diff --git a/LabAutoDiff/reverse.jl b/LabAutoDiff/reverse.jl
new file mode 100644
index 0000000000000000000000000000000000000000..857038723942848d56e528829cd1a547414a1e26
--- /dev/null
+++ b/LabAutoDiff/reverse.jl
@@ -0,0 +1,63 @@
+# A node of the computation tape: records the operator that produced the
+# value, the argument nodes, the forward value, and the accumulated adjoint.
+mutable struct Node
+    op::Union{Nothing,Symbol}
+    args::Vector{Node}
+    value::Float64
+    derivative::Float64
+end
+# Convenience constructor: a fresh node starts with a zero adjoint.
+# (The original `where {T}` clause declared a static parameter used nowhere
+# in the signature or body; it has been removed.)
+Node(op, args, value) = Node(op, args, value, 0.0)
+# Leaf node: a constant/input with no producing operator.
+Node(value) = Node(nothing, Node[], value)
+Base.zero(::Node) = Node(0)
+Base.:*(x::Node, y::Node) = Node(:*, [x, y], x.value * y.value)
+Base.:+(x::Node, y::Node) = Node(:+, [x, y], x.value + y.value)
+Base.:-(x::Node, y::Node) = Node(:-, [x, y], x.value - y.value)
+Base.:/(x::Node, y::Number) = x * Node(inv(y))
+Base.:^(x::Node, n::Integer) = Base.power_by_squaring(x, n)
+
+# Post-order DFS: `topo` ends with `f` last, so reversing it yields an order
+# where each node is visited before its arguments during backpropagation.
+function topo_sort!(visited, topo, f::Node)
+    if !(f in visited)
+        push!(visited, f)
+        for arg in f.args
+            topo_sort!(visited, topo, arg)
+        end
+        push!(topo, f)
+    end
+end
+
+# Propagate `f.derivative` to the adjoints of `f`'s arguments (one local
+# chain-rule step). Leaves (op === nothing) have nothing to propagate.
+function _backward!(f::Node)
+    if isnothing(f.op)
+        return
+    elseif f.op == :+
+        for arg in f.args
+            arg.derivative += f.derivative
+        end
+    elseif f.op == :- && length(f.args) == 2
+        f.args[1].derivative += f.derivative
+        f.args[2].derivative -= f.derivative
+    elseif f.op == :* && length(f.args) == 2
+        f.args[1].derivative += f.derivative * f.args[2].value
+        f.args[2].derivative += f.derivative * f.args[1].value
+    else
+        error("Operator `$(f.op)` not supported yet")
+    end
+end
+
+# Reverse pass: reset all adjoints, seed d f/d f = 1, then run the local
+# backward steps in reverse topological order.
+function backward!(f::Node)
+    topo = typeof(f)[]
+    topo_sort!(Set{typeof(f)}(), topo, f)
+    reverse!(topo)
+    for node in topo
+        node.derivative = 0
+    end
+    f.derivative = 1
+    for node in topo
+        _backward!(node)
+    end
+    return f
+end
+
+function reverse_diff(loss, model, W::Tuple, X, y)
+    W_nodes = broadcast.(Node, W)
+    expr = loss(model(W_nodes, Node.(X)), Node.(y))
+    backward!(expr)
+    return broadcast.(w -> w.derivative, W_nodes)
+end