diff --git a/LabAutoDiff/Project.toml b/LabAutoDiff/Project.toml
index e608bc0bd229a5ee3aae6a011c18206c3475e006..e61424faa7c9901d9c54c4e03cd49cfb26b2df1c 100644
--- a/LabAutoDiff/Project.toml
+++ b/LabAutoDiff/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
diff --git a/LabAutoDiff/lab.jl b/LabAutoDiff/lab.jl
index 9b6d4e7a2654574bbb0c02eb05ece82ead9c8ea3..38c7d6d62b655ffe11b79c53bbed1b395b422c3d 100644
--- a/LabAutoDiff/lab.jl
+++ b/LabAutoDiff/lab.jl
@@ -20,6 +20,7 @@ function random_moon(num_data; noise = 0.1)
 end
 X, y = random_moon(num_data)
 
+using Colors
 function plot_moon(model, W, X, y)
 	col = [Colors.JULIA_LOGO_COLORS.red, Colors.JULIA_LOGO_COLORS.blue]
 	scatter(X[:, 1], X[:, 2], markerstrokewidth=0, color = col[round.(Int, (3 .+ y) / 2)], label = "")
diff --git a/LabAutoDiff/lab_reverse.jl b/LabAutoDiff/lab_reverse.jl
new file mode 100644
index 0000000000000000000000000000000000000000..1c70e7610dc88876f8befdc24cec9207acfeb9cb
--- /dev/null
+++ b/LabAutoDiff/lab_reverse.jl
@@ -0,0 +1,39 @@
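+# Check that the reverse-mode gradients from `reverse.jl` match the
+# forward-mode `forward_diff` gradients on the moon dataset.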
+include("lab.jl")
+using Test, LinearAlgebra
+
+X, y = random_moon(num_data)
+W1 = rand(size(X, 2), num_hidden)
+W2 = rand(num_hidden)
+W = (W1, W2)
+
+include(joinpath(@__DIR__, "reverse.jl"))
+
+∇f = @time forward_diff(mse, identity_activation, W, X, y)
+∇r = @time reverse_diff(mse, identity_activation, W, X, y)
+
+# The difference should be on the order of `1e-15` unless we got something wrong:
+norm.(∇f .- ∇r)
+@test all(∇f .≈ ∇r)
+
+
+∇f = @time forward_diff(mse, tanh_activation, W, X, y)
+∇r = @time reverse_diff(mse, tanh_activation, W, X, y)
+
+# The difference should be on the order of `1e-15` unless we got something wrong:
+norm.(∇f .- ∇r)
+@test all(∇f .≈ ∇r)
+
+∇f = @time forward_diff(mse, relu_activation, W, X, y)
+∇r = @time reverse_diff(mse, relu_activation, W, X, y)
+
+norm.(∇f .- ∇r)
+@test all(∇f .≈ ∇r)
+
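+# Multiclass case: one-hot encoded targets, a second weight matrix mapping the
+# hidden layer to one output per class, and the cross-entropy loss.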
+Y_encoded = one_hot_encode(y)
+W = (rand(size(X, 2), num_hidden), rand(num_hidden, size(Y_encoded, 2)))
+
+∇f = @time forward_diff(cross_entropy, relu_softmax, W, X, Y_encoded)
+∇r = @time reverse_diff(cross_entropy, relu_softmax, W, X, Y_encoded)
+
+norm.(∇f .- ∇r)
+@test all(∇f .≈ ∇r)
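+
+# A minimal scalar sanity check (hand-picked values) of the `Node` machinery
+# from `reverse.jl`: for f(a, b) = a * b + a, reverse mode should give
+# ∂f/∂a = b + 1 and ∂f/∂b = a.
+a, b = Node(2.0), Node(3.0)
+f_node = a * b + a
+backward!(f_node)
+@test a.derivative ≈ 4.0
+@test b.derivative ≈ 2.0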
diff --git a/LabAutoDiff/reverse.jl b/LabAutoDiff/reverse.jl
new file mode 100644
index 0000000000000000000000000000000000000000..857038723942848d56e528829cd1a547414a1e26
--- /dev/null
+++ b/LabAutoDiff/reverse.jl
@@ -0,0 +1,63 @@
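+# A node in the computation graph: the operation that produced it, its argument
+# nodes, the forward value, and the derivative of the output with respect to
+# this node (the adjoint), accumulated by the backward pass.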
+mutable struct Node
+    op::Union{Nothing,Symbol}
+    args::Vector{Node}
+    value::Float64
+    derivative::Float64
+end
+Node(op, args, value) = Node(op, args, value, 0.0)
+Node(value) = Node(nothing, Node[], value)
+Base.zero(::Node) = Node(0)
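+# Overloaded operators compute the value eagerly and record the operation and
+# its arguments so the backward pass can apply the chain rule later; `/` and
+# `^` are reduced to `*`, so `_backward!` only needs rules for `+`, `-` and `*`.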
+Base.:*(x::Node, y::Node) = Node(:*, [x, y], x.value * y.value)
+Base.:+(x::Node, y::Node) = Node(:+, [x, y], x.value + y.value)
+Base.:-(x::Node, y::Node) = Node(:-, [x, y], x.value - y.value)
+Base.:/(x::Node, y::Number) = x * Node(inv(y))
+Base.:^(x::Node, n::Integer) = Base.power_by_squaring(x, n)
+
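+# Depth-first post-order traversal of the graph: every node appears after all
+# of its arguments, so `topo` ends with the output node `f`.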
+function topo_sort!(visited, topo, f::Node)
+	if !(f in visited)
+		push!(visited, f)
+		for arg in f.args
+			topo_sort!(visited, topo, arg)
+		end
+		push!(topo, f)
+	end
+end
+
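+# One local chain-rule step: push the derivative accumulated on `f` back onto
+# the derivatives of its arguments, according to `f.op`.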
+function _backward!(f::Node)
+	if isnothing(f.op)
+		return
+	elseif f.op == :+
+		for arg in f.args
+			arg.derivative += f.derivative
+		end
+	elseif f.op == :- && length(f.args) == 2
+		f.args[1].derivative += f.derivative
+		f.args[2].derivative -= f.derivative
+	elseif f.op == :* && length(f.args) == 2
+		f.args[1].derivative += f.derivative * f.args[2].value
+		f.args[2].derivative += f.derivative * f.args[1].value
+	else
+		error("Operator `$(f.op)` not supported yet")
+	end
+end
+
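+# Reverse pass: reset all derivatives, seed the output node with derivative 1
+# and propagate derivatives in reverse topological order, output to inputs.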
+function backward!(f::Node)
+	topo = typeof(f)[]
+	topo_sort!(Set{typeof(f)}(), topo, f)
+	reverse!(topo)
+	for node in topo
+		node.derivative = 0
+	end
+	f.derivative = 1
+	for node in topo
+		_backward!(node)
+	end
+	return f
+end
+
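+# Wrap the weights, inputs and targets in `Node`s, build the scalar loss graph
+# by running the model and loss, run the backward pass and read the gradient
+# of each weight off its node.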
+function reverse_diff(loss, model, W::Tuple, X, y)
+	W_nodes = broadcast.(Node, W)
+	expr = loss(model(W_nodes, Node.(X)), Node.(y))
+	backward!(expr)
+	return broadcast.(w -> w.derivative, W_nodes)
+end