
Piotr Mikuła


Simple MNIST recognition in Julia using Flux

Brief introduction

Neural networks

Resources:
Paper:

Video:

Julia

Julia is an open-source, compiled, high-level programming language (more about Julia).
It can be used, for example, in VS Code with the extension https://github.com/julia-vscode/julia-vscode or in Juno (which is built on Atom).

Flux

Flux is a deep learning library written in Julia. More details can be found in the paper https://arxiv.org/pdf/1811.01457.pdf, in the GitHub repository https://github.com/FluxML/Flux.jl, or on the home page https://fluxml.ai/.

MNIST

MNIST is a dataset of handwritten digit images (digits 0 to 9).

Program

Import libraries

To download and install libraries in Julia, you can use the Pkg.add function:

#Example
using Pkg
Pkg.add("Flux")

or open the REPL and type ] add Flux.
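The code in this post loads MLDatasets and CUDA in addition to Flux. A minimal sketch of installing all three in one call, assuming the default package registry is reachable (Pkg.add also accepts a vector of package names):

```julia
using Pkg

# Install every package that the `using` lines in this post rely on
Pkg.add(["Flux", "MLDatasets", "CUDA"])
```

The REPL equivalent is ] add Flux MLDatasets CUDA.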

Using libraries in project file

using Flux, MLDatasets, CUDA
using Flux: train!, onehotbatch

Get data

The MLDatasets library contains several vision datasets, including MNIST.
MLDatasets.MNIST.traindata() returns an array containing 60000 28x28 grayscale images and a vector of labels (numbers from 0 to 9). MLDatasets.MNIST.testdata() returns 10000 images and labels.

# Load training data (images, labels)
x_train, y_train = MLDatasets.MNIST.traindata()
# Load test data (images, labels)
x_test, y_test = MLDatasets.MNIST.testdata()
# Convert grayscale pixel values to Float32
x_train = Float32.(x_train)
x_test = Float32.(x_test)
# One-hot encode the training labels (10 rows, one per digit 0-9)
y_train = Flux.onehotbatch(y_train, 0:9)
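To make the label encoding concrete, here is a hand-rolled sketch of one-hot encoding in plain Julia. The function name onehot_sketch is hypothetical and only illustrates the idea; Flux.onehotbatch returns its own compact one-hot matrix type rather than a dense Bool matrix.

```julia
# Illustration only: build a dense Bool matrix with one row per class
# and one column per label (not how Flux stores one-hot data).
function onehot_sketch(labels::AbstractVector{<:Integer}, classes)
    return [label == class for class in classes, label in labels]
end

labels = [5, 0, 9]
encoded = onehot_sketch(labels, 0:9)

println(size(encoded))               # (10, 3)
println(findfirst(encoded[:, 1]))    # 6: digit 5 sits in row 6 because rows are 1-based
println(sum(encoded, dims=1))        # every column contains exactly one `true`
```

Each column is all zeros except for a single entry marking the digit, which is the shape the loss function expects for its targets.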

Create model

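The model contains two dense layers:

  • 1: a layer with 784 inputs and 256 outputs (ReLU activation)
  • 2: a layer with 256 inputs and 10 outputs (one raw score per digit)
model = Chain(
    Dense(784, 256, relu),
    # Output raw scores (logits): logitcrossentropy applies softmax itself,
    # so no relu or softmax is applied here
    Dense(256, 10)
)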
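As a rough illustration of how data flows through layers of these sizes (784 → 256 → 10), the sketch below writes the forward pass by hand in plain Julia. All names ending in _sketch, the random weights, and the 0.01 scaling are assumptions for illustration; Flux initialises and stores its weights differently, and no final activation is applied here.

```julia
# Illustrative forward pass; not Flux's implementation.
relu_sketch(z) = max(zero(z), z)

# Weights are (outputs x inputs), so W * x maps inputs to outputs
W1 = randn(Float32, 256, 784) .* 0.01f0
b1 = zeros(Float32, 256)
W2 = randn(Float32, 10, 256) .* 0.01f0
b2 = zeros(Float32, 10)

# Layer 1 with ReLU, then layer 2 producing 10 raw scores
forward_sketch(x) = W2 * relu_sketch.(W1 * x .+ b1) .+ b2

x_single = rand(Float32, 784)       # one flattened 28x28 image
x_batch = rand(Float32, 784, 32)    # 32 flattened images, one per column

println(size(forward_sketch(x_single)))   # (10,)
println(size(forward_sketch(x_batch)))    # (10, 32)
```

The same function handles a single image or a whole batch because each image occupies one column.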

Training

The loss function is cross-entropy. This project uses logitcrossentropy because, as the Flux documentation says:

"This is mathematically equivalent to crossentropy(softmax(ŷ), y), but is more numerically stable than using functions crossentropy and softmax separately."
Flux documentation

Docs: https://fluxml.ai/Flux.jl/stable/models/losses/
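The numerical-stability point in the quote can be seen with a small plain-Julia sketch for a single sample. The helper names are hypothetical and the formulas are simplified textbook versions, not Flux's implementations (Flux also averages over the batch and guards its logarithms).

```julia
# Hypothetical helpers for illustration; not Flux's code.

# Softmax with the usual max-subtraction trick
function softmax_sketch(z)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

# Cross-entropy on probabilities: takes the log of values that may be exactly 0
crossentropy_sketch(p, y) = -sum(y .* log.(p))

# Cross-entropy on raw scores: log-softmax is computed directly via log-sum-exp
function logitcrossentropy_sketch(z, y)
    m = maximum(z)
    logp = z .- m .- log(sum(exp.(z .- m)))
    return -sum(y .* logp)
end

y = [0.0, 1.0, 0.0]           # one-hot target: second class

z = [2.0, 1.0, 0.1]           # moderate scores: both routes agree
println(isapprox(crossentropy_sketch(softmax_sketch(z), y), logitcrossentropy_sketch(z, y)))  # true

z_big = [1000.0, 0.0, 0.0]    # extreme scores: two softmax outputs underflow to 0
println(crossentropy_sketch(softmax_sketch(z_big), y))   # NaN
println(logitcrossentropy_sketch(z_big, y))              # 1000.0
```

Working on raw scores avoids ever taking the logarithm of a probability that has rounded to zero.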

Define loss function

loss(x, y) = Flux.Losses.logitcrossentropy(model(x), y)

The optimizer is Adam (adaptive moment estimation), which uses the gradients of the loss function to update the model parameters.
Docs: https://fluxml.ai/Flux.jl/stable/training/optimisers
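For intuition, a single Adam update can be sketched in plain Julia as below. The function name adam_step_sketch! and the toy objective are hypothetical; the default values for beta1, beta2 and epsilon are the commonly cited ones and are an assumption here, not taken from Flux's source. The learning rate 0.0001 matches the one used in this post.

```julia
# One Adam update for parameters theta given gradient g. m and v are running
# estimates of the gradient mean and squared gradient; t is the 1-based step.
function adam_step_sketch!(theta, g, m, v, t; lr=0.0001, beta1=0.9, beta2=0.999, epsilon=1e-8)
    @. m = beta1 * m + (1 - beta1) * g
    @. v = beta2 * v + (1 - beta2) * g^2
    mhat = m ./ (1 - beta1^t)    # bias correction for the zero-initialised averages
    vhat = v ./ (1 - beta2^t)
    @. theta -= lr * mhat / (sqrt(vhat) + epsilon)
    return theta
end

# Toy problem: minimise (theta - 3)^2, whose gradient is 2 * (theta - 3)
theta = [0.0]
m = zero(theta)
v = zero(theta)

adam_step_sketch!(theta, 2 .* (theta .- 3), m, v, 1)
first_step = theta[1]
println(first_step)    # close to the learning rate, 0.0001

for t in 2:100
    adam_step_sketch!(theta, 2 .* (theta .- 3), m, v, t)
end
println(theta[1])      # has kept moving towards 3
```

Because the update divides by the running gradient magnitude, the step size stays near the learning rate regardless of how large the raw gradient is.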

Define optimizer

optimizer = ADAM(0.0001)

Train model

parameters = params(model)
# Flux.flatten reshapes the 28x28x60000 array into 784x60000 (28*28 = 784);
# the one-hot labels are already a 10x60000 matrix
train_data = [(Flux.flatten(x_train), y_train)]
# Each call makes one pass over train_data; fewer iterations also work
for i in 1:400
    Flux.train!(loss, parameters, train_data, optimizer)
end
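The flattening step mentioned in the comment above can be reproduced with a plain reshape. This is a minimal sketch on a small random array standing in for the MNIST images; it mirrors what the comment describes rather than calling Flux itself.

```julia
# Stand-in for the MNIST array: 5 random "images" of 28x28 pixels
images = rand(Float32, 28, 28, 5)

# Collapse every dimension except the last one (the image index)
flat = reshape(images, :, size(images, 3))

println(size(flat))                            # (784, 5)
println(flat[:, 1] == vec(images[:, :, 1]))    # true: column 1 holds image 1
```

Each image becomes one 784-element column, which matches the 784 inputs of the first layer.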

Checking results

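x_test_flat = Flux.flatten(x_test)
# Count the test images whose highest-scoring class matches the label
correct = 0
for i in 1:length(y_test)
    # findmax returns (value, index); indices start at 1, digits at 0
    if findmax(model(x_test_flat[:, i]))[2] - 1 == y_test[i]
        # `global` is needed when this loop runs at the top level of a script
        global correct += 1
    end
end
println(correct / length(y_test))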
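The same accuracy calculation can be written without an explicit counter. The sketch below uses a small made-up score matrix (3 classes, 3 samples) in place of the model output, so the numbers are purely illustrative. Flux also provides onecold for turning scores back into labels.

```julia
# Made-up scores: one column per sample, one row per class (classes 0, 1, 2)
scores = [0.1 0.7 0.2;
          0.8 0.2 0.3;
          0.1 0.1 0.5]
labels = [1, 0, 1]

# argmax gives the 1-based row of the highest score; subtract 1 to get the class
predicted = [argmax(col) - 1 for col in eachcol(scores)]
accuracy = count(predicted .== labels) / length(labels)

println(predicted)    # [1, 0, 2]
println(accuracy)     # two of three predictions match the labels
```

Running the model once on the whole flattened test matrix and applying this to its columns avoids calling the model 10000 times.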

Result of test

Accuracy: 0.91

Full code can be found here GIST

Oldest comments (2)

MariaTrabazo

Hi!

I'm trying to execute this exact code, but I get the following error:

ERROR: LoadError: MethodError: no method matching _methods_by_ftype(::Type{Tuple{typeof(ChainRulesCore.rrule),Flux.Optimise.var"#15#21"{typeof(loss),Tuple{Array{Float32,2},Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}}}}}, ::Int64, ::UInt64, ::Bool, ::Base.RefValue{UInt64}, ::Base.RefValue{UInt64}, ::Ptr{Int32})
Closest candidates are:
_methods_by_ftype(::Any, ::Int64, ::UInt64) at reflection.jl:838
_methods_by_ftype(::Any, ::Int64, ::UInt64, !Matched::Array{UInt64,1}, !Matched::Array{UInt64,1}) at reflection.jl:841

Could someone tell me why please?

Thank you in advance :)

Piotr Mikuła

Hi,
Could you please tell me on which line the error occurs?