2 changes: 2 additions & 0 deletions Project.toml
@@ -4,6 +4,8 @@ authors = ["michielstock <[email protected]>"]
version = "0.1.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5 changes: 3 additions & 2 deletions docs/make.jl
@@ -5,9 +5,10 @@ using STMOZOO.Example

makedocs(sitename="STMO ZOO",
format = Documenter.HTML(),
modules=[Example], # add your module
modules=[Example, Softmax], # add your module
pages=Any[
"Example"=> "man/example.md", # add the page to your documentation
"Example" => "man/example.md", # add the page to your documentation
"Softmax" => "man/softmax.md",
])

#=
8 changes: 8 additions & 0 deletions docs/src/man/softmax.md
@@ -0,0 +1,8 @@
## Softmax

Provides some basic functions for computing and sampling from the softmax.

```@docs
softmax
gumbel_max
```
135 changes: 135 additions & 0 deletions notebook/softmax.jl
@@ -0,0 +1,135 @@
### A Pluto.jl notebook ###
# v0.12.4

using Markdown
using InteractiveUtils

# This Pluto notebook uses @bind for interactivity. When running this notebook outside of Pluto, the following 'mock version' of @bind gives bound variables a default value (instead of an error).
macro bind(def, element)
quote
local el = $(esc(element))
global $(esc(def)) = Core.applicable(Base.get, el) ? Base.get(el) : missing
el
end
end

# ╔═╡ 58c174f6-347c-11eb-3245-2b0f0eaaf5ac
using Plots, STMOZOO.Softmax, PlutoUI

# ╔═╡ e9eb4324-347b-11eb-3ba0-3948ac778eab
md"""
# Softmax as a MaxEnt model
The softmax is a popular activation function in machine learning and decision theory. In this project, we show how the softmax arises from a maximum entropy (MaxEnt) model. Furthermore, we show how we can efficiently sample from this distribution using the Gumbel-max trick.
"""

# ╔═╡ 12e07860-347c-11eb-2620-89d253129c58
md"""
## Choosing items according to utility
Suppose we have a choice out of $n$ options. For example, we might need to choose what to have for dinner from our list or decide which of our many projects we want to spend time on. Not all these options are equally attractive: each option $i$ has an associated utility value $x_i$.

A straightforward decision model might just be choosing the option with the largest utility:

$$\max_{i\in\{1,\ldots,n\}} x_i\,.$$

This seems sensible, though this approach has some drawbacks:

- What if two options have nearly the same utility? We would expect both to have an approximately equal chance of being chosen, which is not the case here.
- The maximum is not a smooth function! Changing the utility values only influences the decision when it changes which item has the largest utility.
- We would like items with a lower utility to be picked as well, at least some of the time!

In a different approach, we use optimization to choose a decision vector $\mathbf{q}\in\Delta^{n-1}$. In addition to maximizing the *average utility* $\langle\mathbf{q},\mathbf{x}\rangle$, we also want to maximize the entropy:

$$H(\mathbf{q}) = -\sum_{i=1}^nq_i\log(q_i)\,.$$

This drives $\mathbf{q}$ to be as uniform as possible, and there is a long history of arguments for why maximizing entropy is a sensible principle. The Lagrangian of this problem is:

$$L(\mathbf{q};\kappa, \nu)=H(\mathbf{q}) + \kappa(\langle\mathbf{q},\mathbf{x}\rangle-u_\text{min}) + \nu\left(\sum_{i=1}^nq_i -1\right)\,,$$

where the last term enforces that $\mathbf{q}$ sums to one.

Computing the partial derivative w.r.t. $q_i$ and setting it equal to zero,

$$\frac{\partial L(\mathbf{q};\kappa, \nu)}{\partial q_i}=-\log(q^\star_i) - 1 + \kappa x_i + \nu=0\,,$$

we find

$$q^\star_i\propto \exp(\kappa x_i)\,.$$

So, keeping the $\mathbf{q}$ normalized, we obtain the *softmax*:

$$q_i = \frac{\exp(\kappa x_i)}{\sum_j\exp(\kappa x_j)}\,.$$

Here, $\kappa\ge 0$ is a tuning parameter that determines the dependency on utility.
"""
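The role of $\kappa$ can be sketched numerically. This standalone snippet (the name `softmax_sketch` is illustrative; it mirrors the package's `softmax` but does not depend on it) shows how $\kappa$ interpolates between uniform choice and a hard argmax:

```julia
# Minimal standalone softmax sketch, following the derivation above.
# κ trades off utility against entropy.
softmax_sketch(x; κ=1.0) = (q = exp.(κ .* x); q ./ sum(q))

x = [1.5, 0.1, -1.2]
softmax_sketch(x; κ=0.0)    # κ → 0: uniform, every option equally likely
softmax_sketch(x; κ=1.0)    # intermediate: higher utility → higher probability
softmax_sketch(x; κ=20.0)   # κ large: mass concentrates on the argmax
```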

# ╔═╡ 60c00938-347c-11eb-2f77-4388badf9113
md"Here is an example using dinners with their respective preferences."

# ╔═╡ 36144028-347c-11eb-1734-f527482a6048
dinners_preference = [("rice-lentils", 10.0),
("dhal", 6.0),
("spaghetti", 8.5),
("chinese", 7.5),
("fries", 8.0),
("ribhouse", 9.5),
("hamburger", 8.0),
("pitta", 6.0),
("noodles", 7.0),
("chicken", 8.0),
("curry", 7.5),
("pizza (veg)", 7.0),
("pizza", 8.0)
]

# ╔═╡ 418c6d2c-347c-11eb-0c11-db8d090c5953
dinners = first.(dinners_preference)

# ╔═╡ 50208e18-347c-11eb-1628-6d9587fdf00f
preferences = last.(dinners_preference)

# ╔═╡ 5d734344-347c-11eb-2c10-b5c0adfac424
bar(dinners, preferences)

# ╔═╡ a7bee2aa-347c-11eb-109b-f39f17158444
@bind κ Slider(0.01:0.1:10, default=1)

# ╔═╡ adc7a786-347c-11eb-175f-db354866f500
q = softmax(preferences; κ=κ)

# ╔═╡ d81e8afe-347c-11eb-2c13-a5a3219051ae
bar(dinners, q)

# ╔═╡ 997c7554-347c-11eb-322e-27691246c726
md"""
## The Gumbel-max trick

Given the optimal choice distribution $\mathbf{q}$, how can we sample choices from it? Using the Gumbel-max trick! This is a simple way to sample from a discrete probability distribution given by unnormalized log-probabilities, i.e.

$$p_k\propto \exp(x_k)\,.$$

Just take

$$y=\text{argmax}_{i\in \{1,\ldots, K\}}\, (x_i + g_i)\,,$$

where the $g_i$ are independent draws from a Gumbel distribution.
"""
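The trick can be checked empirically. The following standalone sketch (independent of the package; `gumbel` and `softmax_` are illustrative helpers) compares the empirical frequencies of Gumbel-max samples against the softmax probabilities:

```julia
# Empirical check of the Gumbel-max trick: adding i.i.d. Gumbel(0,1) noise
# to the utilities and taking the argmax samples from the softmax.
gumbel() = -log(-log(rand()))                 # inverse-CDF Gumbel(0,1) sample
softmax_(x) = (q = exp.(x); q ./ sum(q))

x = [1.5, 0.1, -1.2]
n = 100_000
counts = zeros(Int, length(x))
for _ in 1:n
    counts[argmax(x .+ [gumbel() for _ in x])] += 1
end
freqs = counts ./ n                           # ≈ softmax_(x) for large n
```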

# ╔═╡ 73a46a84-347d-11eb-10fa-3365cbee1205
gumbel_max(dinners, preferences)

# ╔═╡ 80e7b656-347d-11eb-2546-a36f30213001
sampled_dinners = [gumbel_max(dinners, preferences) for i in 1:1000]

# ╔═╡ Cell order:
# ╟─e9eb4324-347b-11eb-3ba0-3948ac778eab
# ╟─12e07860-347c-11eb-2620-89d253129c58
# ╟─60c00938-347c-11eb-2f77-4388badf9113
# ╠═36144028-347c-11eb-1734-f527482a6048
# ╠═418c6d2c-347c-11eb-0c11-db8d090c5953
# ╠═50208e18-347c-11eb-1628-6d9587fdf00f
# ╠═58c174f6-347c-11eb-3245-2b0f0eaaf5ac
# ╠═5d734344-347c-11eb-2c10-b5c0adfac424
# ╠═a7bee2aa-347c-11eb-109b-f39f17158444
# ╠═adc7a786-347c-11eb-175f-db354866f500
# ╠═d81e8afe-347c-11eb-2c13-a5a3219051ae
# ╟─997c7554-347c-11eb-322e-27691246c726
# ╠═73a46a84-347d-11eb-10fa-3365cbee1205
# ╠═80e7b656-347d-11eb-2546-a36f30213001
3 changes: 2 additions & 1 deletion src/STMOZOO.jl
@@ -3,6 +3,7 @@ module STMOZOO

# execute your source file and export the module you made
include("example.jl")
export Example
include("softmax.jl")
export Example, Softmax

end # module
50 changes: 50 additions & 0 deletions src/softmax.jl
@@ -0,0 +1,50 @@
# Michiel Stock
# Softmax module: basic functions for computing and sampling from the softmax.


# all your code is part of the module you are implementing
module Softmax

# you have to import everything you need for your module to work
# if you use a new package, don't forget to add it in the package manager
using Distributions: Gumbel

# export all functions that are relevant for the user
export softmax, gumbel_max

"""
softmax(x::Vector; κ::Number=1.0)

Computes the softmax for a vector `x`. `κ` is a hyperparameter that
determines the trade-off between utility and entropy.
"""
function softmax(x::Vector; κ::Number=1.0)
q = exp.(κ .* x)
q ./= sum(q)
return q
end
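An aside for reviewers: `exp.(κ .* x)` can overflow for large `κ .* x`. A numerically stable variant (a sketch, not part of this PR; the name `softmax_stable` is illustrative) subtracts the maximum first, which leaves the normalized result unchanged:

```julia
# Sketch of a numerically stable softmax (illustrative, not in the PR).
# The softmax is shift-invariant, so subtracting maximum(z) avoids
# overflow without changing the result.
function softmax_stable(x::Vector; κ::Number=1.0)
    z = κ .* x
    q = exp.(z .- maximum(z))
    return q ./ sum(q)
end
```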

"""
gumbel_max(items::Vector, x::Vector; κ::Number=1.0)

Samples an item from `items` using the softmax given utilities `x`.
`κ` is a hyperparameter that determines the trade-off between utility
and entropy.
"""
function gumbel_max(items::Vector, x::Vector; κ::Number=1.0)
@assert length(items) == length(x) "length of `items` and `x` do not match"
i = κ .* x .+ rand(Gumbel(), length(x)) |> argmax
return items[i]
end

"""
gumbel_max(x::Vector; κ::Number=1.0)

Samples an item using the softmax given utilities `x`.
Returns the index of the chosen item.
`κ` is a hyperparameter that determines the trade-off between utility
and entropy.
"""
gumbel_max(x::Vector; κ::Number=1.0) = κ .* x .+ rand(Gumbel(), length(x)) |> argmax

end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -1,4 +1,5 @@
using Test

include("example.jl")
include("softmax.jl")
# add here the file with your unit tests
25 changes: 25 additions & 0 deletions test/softmax.jl
@@ -0,0 +1,25 @@
@testset "Softmax" begin
using STMOZOO.Softmax

x = [1.5, 0.1, -1.2]
items = [:A, :B, :C]

@testset "softmax" begin
qx = softmax(x)

@test qx isa Vector
@test all(qx .≥ 0.0)
@test sum(qx) ≈ 1.0
@test qx[1] > qx[2]

@test maximum(softmax(x, κ=20)) ≈ 1.0
end

@testset "gumbel" begin

@test (gumbel_max(items, x) ∈ items)
@test (items[gumbel_max(x)] ∈ items)
@test (gumbel_max(items, x; κ=20) == :A)
end

end