slaveOftime
machine_learning, torchsharp, fsharp, notes

Make more names - part 1

2023-03-13

id: d638c6a7-34bb-4151-b31b-30b34730df42 title: Make more names - part 1 keywords: machine_learning,torchsharp,fsharp,notes description: Yeah... maybe it is time to start learning machine learning. I will follow along with the star Andrej Karpathy createTime: 2023-03-13

// Install dependencies
#r "nuget:Plotly.NET.Interactive"
#r "nuget:TorchSharp,0.99.1"
#r "nuget:libtorch-cuda-11.7-win-x64, 1.13.0.1"
#r "nuget:Microsoft.DotNet.Interactive.Formatting, 1.0.0-beta.23152.2"
Installed Packages
  • libtorch-cuda-11.7-win-x64, 1.13.0.1
  • Microsoft.DotNet.Interactive.Formatting, 1.0.0-beta.23152.2
  • Plotly.NET.Interactive, 4.1.0
  • TorchSharp, 0.99.1
Loading extensions from `C:\Users\cnbinwew\.nuget\packages\plotly.net.interactive\4.1.0\interactive-extensions\dotnet\Plotly.NET.Interactive.dll`
open System
open System.IO

open Plotly.NET
open TorchSharp
open type TorchSharp.torch.nn.functional

open Microsoft.DotNet.Interactive.Formatting

// Render tensors as plain text in notebook output, using TorchSharp's default tensor string style.
// The preferred-MIME override must target the tensor type that the formatter below is registered for.
Formatter.SetPreferredMimeTypesFor(typeof<torch.Tensor>, "text/plain")
Formatter.Register(fun (x: torch.Tensor) -> x.ToString(TorchSharp.TensorStringStyle.Default))

let print x = Formatter.ToDisplayString x |> printfn "%s"
let (@) (x: torch.Tensor) (y: torch.Tensor) = x.matmul y
// Every name in the dataset file (one per line), in a random order.
let words =
    File.ReadAllLines "MakeMore.names.txt"
    |> Array.sortBy (fun _ -> Random.Shared.Next())
    |> List.ofArray
words |> Seq.take 5
[ antonia, baron, danay, miron, audi ]
// The distinct characters used across all names, in sorted order.
let chars =
    words
    |> Seq.concat
    |> Seq.distinct
    |> Seq.sort
    |> Seq.toList

// Token <-> index mapping:
//   '<' -> 0                      (start padding of a context window)
//   letters -> 1 .. chars.Length  (in sorted order)
//   '>' -> chars.Length + 1       (the end-of-name token, i.e. size - 1)
let lookupTable =
    [ yield ('<', 0)
      yield! chars |> List.mapi (fun i c -> (c, i + 1))
      yield ('>', chars.Length + 1) ]
    |> Map.ofList

let size = lookupTable.Count
let ctoi c = lookupTable[c]
let itoc i = lookupTable |> Map.findKey (fun _ index -> index = i)
// Number of previous characters the model sees when predicting the next one.
let blockSize = 3

// Build (context, next character) training pairs. Each word is left-padded with
// `blockSize` '<' tokens (index 0); every window of `blockSize` indices is paired with
// the index that follows it, and the final window of a word is paired with '>' (size - 1).
let X, Y =
    let endIndex = size - 1
    let samplesOf (word: string) =
        let targets = Array.append (Array.map ctoi (word.ToCharArray())) [| endIndex |]
        let padded = Array.append (Array.zeroCreate blockSize) targets
        Array.init targets.Length (fun i -> padded[i .. i + blockSize - 1], targets[i])
    let inputs, labels = words |> Seq.collect samplesOf |> Seq.toArray |> Array.unzip
    torch.tensor(array2D inputs), torch.tensor(labels)
// Check the input and label pairs: print the first 21 examples as "context => next char".
// X and Y are built from int arrays, so their element type is Int32.
torch.cat(
    System.Collections.Generic.List [
        X[torch.arange(21)]
        Y[torch.arange(21)].view(-1, 1)
    ],
    1
).data<int>()
// Each row is `blockSize` context indices followed by the label index.
|> Seq.chunkBySize (blockSize + 1)
|> Seq.iter (fun row ->
    let context = row[.. blockSize - 1] |> Seq.map itoc |> String.Concat
    printfn "%s => %s" context (string (itoc (row[blockSize])))
)
<<< => a
<<a => n
<an => t
ant => o
nto => n
ton => i
oni => a
nia => >
<<< => b
<<b => a
<ba => r
bar => o
aro => n
ron => >
<<< => d
<<d => a
<da => n
dan => a
ana => y
nay => >
<<< => m
// Split the examples (rows of X / Y), not the words. Every word yields one example per
// character plus one for the end token, so X has more rows than `words` has entries;
// counting words here would leave most of X unused. Rows are taken in order, so a split
// boundary may fall inside one word's run of examples.
let total = X.shape[0] |> int
let trainCount = float total * 0.8 |> int
let devCount = float total * 0.1 |> int
// The test split takes the remainder, so rounding never drops examples.
let testCount = total - trainCount - devCount

// Training split, dev/validation split, test split
// 80%             10%                   10%
let X_train = X[torch.arange(trainCount)]
let Y_train = Y[torch.arange(trainCount)]
let X_dev = X[torch.arange(trainCount, trainCount + devCount)]
let Y_dev = Y[torch.arange(trainCount, trainCount + devCount)]
let X_test = X[torch.arange(trainCount + devCount, trainCount + devCount + testCount)]
let Y_test = Y[torch.arange(trainCount + devCount, trainCount + devCount + testCount)]
// Build the MLP: an embedding table C, one tanh hidden layer (W1, b1)
// and an output layer (W2, b2) producing one logit per token.
let inputSize = 10 // embedding dimension (columns of C)
let outputSize = size
// The hidden layer takes the concatenated embeddings of the whole context window,
// so its width follows from blockSize and the embedding dimension (3 * 10 = 30).
let layer1InputSize = blockSize * inputSize
let layer1OutputSize = 200

let batchSize = 32

// Seeded generator so parameter initialisation is reproducible.
let g = torch.Generator().manual_seed(2147483647)
let C = torch.randn(size, inputSize, generator = g)
let W1 = torch.randn(layer1InputSize, layer1OutputSize, generator = g)
let b1 = torch.randn(layer1OutputSize, generator = g)

let W2 = torch.randn(layer1OutputSize, outputSize, generator = g)
let b2 = torch.randn(outputSize, generator = g)

let parameters = [C; W1; b1; W2; b2]
printfn $"total params: {parameters |> Seq.sumBy (fun x -> x.NumberOfElements)}"

for p in parameters do
    p.requires_grad <- true
total params: 12108
// What is an embedding lookup?
// Indexing a tensor with an integer tensor replaces every value v in x with row v of c,
// so e = c[x] has x's shape followed by c's row length (4x4x3 here).
// Every value in x is therefore a row index into c and must be smaller than the
// number of rows of c (x's values are drawn from [0, 2); c has 3 rows).
let c = torch.randint(5, [| 3; 3 |])
let x = torch.randint(2, [| 4; 4 |])
let e = c[x]
[ c; x; e ] |> List.iter print

// e[0, 1, 1] is c[x[0, 1], 1]
print (e[0, 1, 1])
[3x3], type = Int64, device = cpu
 0 4 4
 0 2 3
 0 0 3

[4x4], type = Int64, device = cpu
 0 1 1 0
 1 0 1 1
 1 0 0 1
 0 0 0 0

[4x4x3], type = Int64, device = cpu
[0,..,..] =
 0 4 4
 0 2 3
 0 2 3
 0 4 4

[1,..,..] =
 0 2 3
 0 4 4
 0 2 3
 0 2 3

[2,..,..] =
 0 2 3
 0 4 4
 0 4 4
 0 2 3

[3,..,..] =
 0 4 4
 0 4 4
 0 4 4
 0 4 4

[], type = Int64, device = cpu, value = 2
/// Forward pass of the MLP.
/// `x` holds character indices; each consecutive group of `blockSize` indices is one
/// context (e.g. a [batch; blockSize] batch, or a single length-`blockSize` context).
/// Returns the un-normalised logits, one row per context and one column per token.
let createLogits (x: torch.Tensor) =
    // Embed every index, then flatten each context window into one row.
    let embed = C[x.long()]
    // F# gives `@` lower precedence than `+`: without the parentheses, `a @ W1 + b1`
    // parses as `a @ (W1 + b1)` and the bias is broadcast into the weight matrix.
    let h = (embed.view(-1, layer1InputSize) @ W1) + b1 |> torch.tanh
    (h @ W2) + b2

/// Cross-entropy loss of the `input` logits against integer class labels `target`,
/// reduced to a scalar.
let calcLoss (target: torch.Tensor) (input : torch.Tensor) = cross_entropy(input, target.long())
// Loss of every training step (one mini-batch each), kept for plotting the loss curve.
// The element type must be explicit: the loss is a Float32 scalar read with item<float32>.
let lossi = System.Collections.Generic.List<float32>()
// What does backward do?
// y0.backward() computes d(y0)/d(x0) and stores it in x0.grad(); we then take a
// gradient-descent step of fixed size (learning rate 0.25) on x0.
// For y0 = x0^2 + 1 the gradient is 2 * x0, so each step halves x0.
let x0 = torch.tensor(2., requires_grad = true)
for _ in 1..10 do
    let y0 = x0 * x0 + Scalar.op_Implicit 1
    // Gradients accumulate across backward() calls, so clear the previous one first.
    if x0.grad() <> null then x0.grad().zero_() |> ignore
    y0.backward()
    // Update x0 in place; no_grad keeps the update out of the autograd graph
    // (an in-place change to a leaf that requires grad is rejected otherwise).
    using (torch.no_grad()) (fun _ -> x0.sub_(Scalar.op_Implicit 0.25 * x0.grad()) |> ignore)
    printfn $"x0 = %.3f{x0.item<float>()}, grad = %.3f{x0.grad().item<float>()}"
x0 = 1.000, grad = 4.000
x0 = 0.500, grad = 2.000
x0 = 0.250, grad = 1.000
x0 = 0.125, grad = 0.500
x0 = 0.062, grad = 0.250
x0 = 0.031, grad = 0.125
x0 = 0.016, grad = 0.062
x0 = 0.008, grad = 0.031
x0 = 0.004, grad = 0.016
x0 = 0.002, grad = 0.008
// Start the training: 100k steps of mini-batch gradient descent.
for i in 0 .. 100_000 - 1 do
    // Dispose the tensors created in this iteration when it ends instead of leaving their
    // native memory to the GC (the parameters were created outside this scope).
    use _ = torch.NewDisposeScope()

    // mini batch: `batchSize` random row indices into the training split
    let ix = torch.randint(0, int (X_train.shape[0]), [| batchSize |])
    // forward pass
    let logits = createLogits (X_train[ix])
    let loss = calcLoss (Y_train[ix]) logits
    lossi.Add(loss.item<float32>())

    // Step-wise learning-rate schedule: 0.1, then 0.05, then 0.01 from step 40k on.
    let lr =
        (if i < 30_000 then 0.1
         elif i < 40_000 then 0.05
         else 0.01)
        |> Scalar.op_Implicit

    // Gradients accumulate across backward() calls, so clear them first.
    for p in parameters do
        if p.grad() <> null then p.grad().zero_() |> ignore

    loss.backward()

    // Gradient-descent step applied in place. no_grad keeps the update out of the
    // autograd graph (in-place changes to leaves that require grad are rejected otherwise).
    using (torch.no_grad()) (fun _ ->
        for p in parameters do
            p.sub_(lr * p.grad()) |> ignore)

Chart.Line([0 .. lossi.Count - 1], lossi)
// The final loss over the whole training split. Evaluation only, so run it under
// no_grad to avoid recording an autograd graph over every training example.
using (torch.no_grad()) (fun _ -> createLogits X_train |> calcLoss Y_train)
[], type = Float32, device = cpu, value = 2.4804
// The loss for the dev set. Evaluation only, so run it under no_grad
// to avoid recording an autograd graph.
using (torch.no_grad()) (fun _ -> createLogits X_dev |> calcLoss Y_dev)
[], type = Float32, device = cpu, value = 2.607
/// Sample one new name from the trained network, one character at a time.
/// The context starts as `blockSize` '<' tokens (index 0); sampling stops when the
/// end token '>' (index size - 1) is drawn, and that token is not appended.
let generateNameByNetwork () =
    // Pure inference: don't record an autograd graph while sampling.
    use _ = torch.no_grad()
    let endIndex = int64 (size - 1)
    let name = Text.StringBuilder()
    // The last `blockSize` sampled indices, used to predict the next one.
    let mutable context = Array.create blockSize 0L
    let mutable shouldContinue = true

    while shouldContinue do
        let logits = createLogits (torch.tensor(context))
        let probs = softmax(logits, dim = 1)

        // Pick one index from the row, according to the probability of each index.
        // multinomial returns Int64 indices.
        let ix = torch.multinomial(probs, num_samples = 1, replacement = true, generator = g).item<int64>()

        context <- Array.append context[1..] [| ix |]

        if ix = endIndex then
            shouldContinue <- false
        else
            name.Append(itoc (int ix)) |> ignore

    name.ToString()
[1..10] |> Seq.iter (ignore >> generateNameByNetwork >> print)
laed
lde
tt
aan
nri
edy
tnlenn
yysei
rra
mlaw

Do you like this post?