
Make more names - part 2

2023-03-21

id: c1903b5b-cea8-4e16-9b14-7ad2e5e54bd1
title: Make more names - part 2
keywords: machine_learning, torchsharp, fsharp, notes
description: Yeah... maybe it is time to start to learn machine learning. Will follow the star Andrej Karpathy
createTime: 2023-03-21
isHidden: false

Refactoring the learning notes from part 1, following https://www.youtube.com/watch?v=P6sfmUTpUmc&t=3014s

// Install dependencies
#r "nuget:Plotly.NET.Interactive"
#r "nuget:TorchSharp,0.99.3"
#r "nuget:libtorch-cuda-11.7-win-x64,1.13.0.1"
#r "nuget:Microsoft.DotNet.Interactive.Formatting,*-*"
Installed Packages
  • libtorch-cuda-11.7-win-x64, 1.13.0.1
  • Microsoft.DotNet.Interactive.Formatting, 1.0.0-beta.23205.1
  • Plotly.NET.Interactive, 4.1.0
  • TorchSharp, 0.99.3
open System
open System.IO

open Plotly.NET
open TorchSharp
open type TorchSharp.torch.nn.functional

open Microsoft.DotNet.Interactive.Formatting

Formatter.SetPreferredMimeTypesFor(typeof<torch.Tensor>, "text/plain")
Formatter.Register(fun (x:torch.Tensor) -> x.ToString(TorchSharp.TensorStringStyle.Default))

let print x = Formatter.ToDisplayString x |> printfn "%s"
let (@) (x: torch.Tensor) (y: torch.Tensor) = x.matmul y
let (^) x y = Math.Pow(x, y)
let scalar (x: float) = Scalar.op_Implicit x
let words = File.ReadAllLines "MakeMore.names.txt" |> Seq.sortBy (fun _ -> Random.Shared.Next()) |> Seq.toList
words |> Seq.take 5
[ giulio, yaretsi, hailyn, evelyna, avalise ]
let chars =
    let set = System.Collections.Generic.HashSet<char>()
    words |> Seq.iter (fun word -> word |> Seq.iter (set.Add >> ignore))
    set |> Seq.sort |> Seq.toList

let lookupTable =
    Map.ofList [
        '.', 0
        for i, c in List.indexed chars do c, i + 1
    ]

let size = lookupTable.Count
let ctoi c = Map.find c lookupTable
let itoc i = lookupTable |> Map.pick (fun k x -> if x = i then Some k else None)
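A quick sanity check of the two mappings (my own addition; it assumes the names file contains only lowercase a-z, so the vocabulary size is 27 with '.' as the boundary token at index 0):

printfn $"size = {size}, ctoi 'a' = {ctoi 'a'}, itoc 0 = '{itoc 0}'" // size = 27, ctoi 'a' = 1, itoc 0 = '.'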
let n_embed = 10 // the dimensionality of the character embedding vectors
let n_hidden = 100 // the number of neurons in the hidden layer of the MLP
let block_size = 3

let g = torch.Generator().manual_seed(2122123) // for reproducibility
let X, Y =
    [|
        for word in words do
            let mutable context = [for _ in 1..block_size -> 0]
            for c in word do
                let ix = ctoi c
                List.toArray context, ix
                context <- List.append context[1..] [ix]
            List.toArray context, 0
    |]
    |> Array.unzip
    |> fun (x, y) -> 
        torch.tensor(array2D x),
        torch.tensor(y)
// Check the input and label pair
torch.cat(
    System.Collections.Generic.List [
        X[[|0L..20L|]]
        Y[[|0L..20L|]].view(-1, 1)  
    ],
    1
).data<int>()
|> Seq.chunkBySize 4
|> Seq.iter (fun row ->
    printfn "%s => %s" (row[..2] |> Seq.map itoc |> String.Concat) (itoc row[3] |> string)
)
... => g
..g => i
.gi => u
giu => l
iul => i
uli => o
lio => .
... => y
..y => a
.ya => r
yar => e
are => t
ret => s
ets => i
tsi => .
... => h
..h => a
.ha => i
hai => l
ail => y
ily => n
// Training split, test split
// 90%             10%

let total = words.Length
let trainCount = float total * 0.9 |> int
let testCount = float total * 0.1 |> int

let X_train = X[torch.arange(trainCount)]
let Y_train = Y[torch.arange(trainCount)]
let X_test = X[torch.arange(trainCount, trainCount + testCount)]
let Y_test = Y[torch.arange(trainCount, trainCount + testCount)]
type ILayer =
    abstract member Forward: x: torch.Tensor -> torch.Tensor
    abstract member Parameters: torch.Tensor list
    abstract member Out: torch.Tensor
type Linear(fanIn: int, fanOut: int, generator: torch.Generator, ?withBias) =
    let mutable out = Unchecked.defaultof<torch.Tensor>
    
    let mutable weight = torch.randn(fanIn, fanOut, generator = generator) / scalar(fanIn ^ 0.5)
    let bias = if defaultArg withBias true then Some(torch.zeros(fanOut)) else None

    member _.UpdateWeight(fn) = weight <- fn weight

    interface ILayer with
        member _.Forward(x) =
            out <- x @ weight
            out <-
                match bias with
                | None -> out
                | Some bias -> out + bias
            out

        member _.Parameters = [
            weight
            match bias with
            | None -> ()
            | Some bias -> bias
        ]

        member _.Out = out
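The Linear layer above scales its random weights by 1/sqrt(fanIn). The reason: each output of the matmul sums fanIn products, so without the scaling the pre-activation standard deviation grows like sqrt(fanIn). A quick check of that claim (my own sketch, with made-up sizes):

let xDemo = torch.randn(1000, 200)
let wDemo = torch.randn(200, 300) / scalar(200. ^ 0.5)
print ((xDemo @ wDemo).std()) // ≈ 1.0; without the 1/sqrt(fanIn) scaling it would be ≈ sqrt(200) ≈ 14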
type BatchNorm1d(dim: int, ?eps, ?momentum) as this =
    let mutable out = Unchecked.defaultof<torch.Tensor>

    let eps = defaultArg eps 1e-5 |> scalar
    let momentum = defaultArg momentum 0.1 // momentum: update factor for the running statistics
    
    let mutable gamma = torch.ones(dim)
    let beta = torch.zeros(dim)
    
    let mutable running_mean = torch.zeros(dim)
    let mutable running_var = torch.ones(dim)

    member val IsTraining = true with get, set

    member _.UpdateGamma(fn) = gamma <- fn gamma

    member _.RunningVar = running_var
    member _.RunningMean = running_mean
    member _.Gamma = gamma
    member _.Beta = beta

    interface ILayer with
        member _.Forward(x: torch.Tensor) = 
            let xmean = // batch mean
                if this.IsTraining then x.mean([| 0 |], keepdim = true)
                else running_mean
            let xvar = // variance (how spread out the values are): https://pytorch.org/docs/stable/generated/torch.var.html?highlight=var#torch.var
                if this.IsTraining then x.var(0, keepdim = true, unbiased = true)
                else running_var

            let xhat = (x - xmean) / (xvar + eps).sqrt() // Normalize to unit variance
            out <- xhat * gamma + beta

            if this.IsTraining then
                use _ = torch.no_grad()
                running_mean <- scalar(1. - momentum) * running_mean + scalar(momentum) * xmean
                running_var <- scalar(1. - momentum) * running_var + scalar(momentum) * xvar

            out

        member _.Parameters = [ gamma; beta ]

        member _.Out = out
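A minimal sanity check of the BatchNorm1d layer above (my own addition, random demo input): in training mode every feature column should come out with roughly zero mean and unit variance, whatever the scale of the input.

let bnDemo = BatchNorm1d(5) :> ILayer
let bnOut = bnDemo.Forward(torch.randn(1000, 5) * scalar(3.) + scalar(7.))
print (bnOut.mean([| 0 |], keepdim = true)) // every column ≈ 0
print (bnOut.var(0, keepdim = true, unbiased = true)) // every column ≈ 1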
type Tanh() =
    let mutable out = Unchecked.defaultof<torch.Tensor>

    interface ILayer with
        member _.Forward(x: torch.Tensor) = 
            out <- torch.tanh(x)
            out

        member _.Parameters = []

        member _.Out = out
// Build layers

let C = torch.randn(size, n_embed, generator = g)

// let layers: ILayer list = [
//     Linear(n_embed * block_size, n_hidden, generator = g); Tanh()
//     Linear(n_hidden, n_hidden, generator = g); Tanh()
//     Linear(n_hidden, n_hidden, generator = g); Tanh()
//     Linear(n_hidden, n_hidden, generator = g); Tanh()
//     Linear(n_hidden, n_hidden, generator = g); Tanh()
//     Linear(n_hidden, size, generator = g)
// ]

let layers: ILayer list = [
    Linear(n_embed * block_size, n_hidden, generator = g); BatchNorm1d(n_hidden); Tanh()
    Linear(n_hidden, n_hidden, generator = g); BatchNorm1d(n_hidden); Tanh()
    Linear(n_hidden, n_hidden, generator = g); BatchNorm1d(n_hidden); Tanh()
    Linear(n_hidden, n_hidden, generator = g); BatchNorm1d(n_hidden); Tanh()
    Linear(n_hidden, n_hidden, generator = g); BatchNorm1d(n_hidden); Tanh()
    Linear(n_hidden, size, generator = g); BatchNorm1d(size)
]

do  
    use _ = torch.no_grad()
    // Make the last layer less confident by scaling down the final BatchNorm gain
    // (the commented line below is the equivalent when the output layer is a plain Linear)
    // (layers |> List.last :?> Linear).UpdateWeight(fun x -> x * scalar(0.1))
    (layers |> List.last :?> BatchNorm1d).UpdateGamma(fun x -> x * scalar(0.1))
    // Boost the other Linear layers with the tanh gain of 5/3 (see the quick check after this block)
    for layer in layers[0..layers.Length-2] do
        match layer with
        | :? Linear as linear -> linear.UpdateWeight(fun w -> w * scalar(5. / 3.))
        | _ -> ()
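Where does the 5/3 gain come from? tanh squashes a unit-Gaussian input to a standard deviation of roughly 0.63, so multiplying the incoming weights by about 5/3 undoes that shrinkage and keeps the activations near unit scale. A rough empirical check (my own sketch):

let zDemo = torch.randn(100_000)
printfn $"std(tanh(z)) = %.3f{torch.tanh(zDemo).std().item<float32>()}" // ≈ 0.63, and 1 / 0.63 is close to 5/3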
// Prepare parameters

let parameters = [
    C
    yield! layers |> Seq.map (fun x -> x.Parameters) |> Seq.concat
]

parameters |> Seq.iter (fun p -> p.requires_grad <- true)


let createLogits isTraining (x: torch.Tensor) =
    let embed = C[x.long()] // embed the characters into vectors
    let mutable x = embed.view(embed.shape[0], -1) // concatenate the vectors
    
    for layer in layers do
        match layer with
        | :? BatchNorm1d as b -> b.IsTraining <- isTraining
        | _ -> ()
        x <- layer.Forward(x)

    x

let calcLoss (target: torch.Tensor) (input : torch.Tensor) = cross_entropy(input, target.long())
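cross_entropy fuses log-softmax with the negative log-likelihood of the target class. A small equivalence check with made-up tensors (my own sketch; demoLogits and demoTargets are hypothetical):

let demoLogits = torch.randn(4, size)
let demoTargets = torch.tensor([| 1L; 0L; 5L; 2L |])
let manualNll = softmax(demoLogits, dim = 1).log().gather(1L, demoTargets.view(-1, 1)).mean() * scalar(-1.)
print manualNll // manual negative log-likelihood
print (cross_entropy(demoLogits, demoTargets)) // should match the line above, up to float error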


printfn "Total parameters %d" (parameters |> Seq.sumBy (fun x -> x.NumberOfElements))
Total parameters 47551
// Used to keep track all the loss on every epoch
let lossi = System.Collections.Generic.List<float32>()
let upgradeToData = System.Collections.Generic.Dictionary<int, System.Collections.Generic.List<float32>>()
let epochs = 100_000
let batchSize = 32

// Start the training
for i in 1..epochs do
    // mini batch, get a batch of training set for training
    let ix = torch.randint(0, int X_train.shape[0], [| batchSize |], generator = g)
    let Xb, Yb = X_train[ix], Y_train[ix]

    // forward pass
    let logits = createLogits true Xb
    let loss = calcLoss Yb logits
    
    // backward pass
    for layer in layers do
        layer.Out.retain_grad() // For debug only

    for p in parameters do
        if p.grad() <> null then p.grad().zero_() |> ignore
    loss.backward()

    // Calculate learning rates
    let learningRate =
        if i < 20_000 then 0.1
        else 0.01

    // update
    for p in parameters do
        let newData = p - scalar(learningRate) * p.grad()
        p.data<float32>().CopyFrom(newData.data<float32>().ToArray())

    lossi.Add(loss.item<float32>())

    use _ = torch.no_grad()
    // Track log10(update magnitude / parameter magnitude) for every parameter
    for pi, p in List.indexed parameters do
        if upgradeToData.ContainsKey(pi) |> not then upgradeToData[pi] <- System.Collections.Generic.List<float32>()
        upgradeToData[pi].Add((scalar(learningRate) * p.grad().std() / p.std()).log10().item<float32>())

    if i <> 0 && i % 10_000 = 0 then
        printfn $"loss = {loss.item()} \t learning rate = {learningRate}"
loss = 2.1214485 	 learning rate = 0.1
loss = 2.2190986 	 learning rate = 0.01
loss = 1.9041289 	 learning rate = 0.01
loss = 1.8256818 	 learning rate = 0.01
loss = 1.9673663 	 learning rate = 0.01
loss = 2.1136038 	 learning rate = 0.01
loss = 1.988074 	 learning rate = 0.01
loss = 1.7320899 	 learning rate = 0.01
loss = 1.6809943 	 learning rate = 0.01
loss = 1.811196 	 learning rate = 0.01
let lossiY = torch.tensor(lossi.ToArray()).view(-1, 1000).mean([|1L|]).data<float32>()
Chart.Line([1L..lossiY.Count], lossiY)
|> Chart.withSize(900, 400)
// The final loss
createLogits false X_train |> calcLoss Y_train
[], type = Float32, device = cpu, value = 1.8475
// The loss for the dev set
createLogits false X_test |> calcLoss Y_test
[], type = Float32, device = cpu, value = 2.2405
let generateNameByNetwork () =
    let mutable shouldContinue = true
    // the rolling context used to predict the next char, e.g. "..." => ?
    let mutable context = [| for _ in 1..block_size -> 0L |]
    let name = Text.StringBuilder()

    while shouldContinue do
        let logits = createLogits false (torch.tensor(context).view(-1, block_size))
        let probs = softmax(logits, dim = 1)
        // Sample one index from the row, according to the probability distribution in that row
        let ix = torch.multinomial(probs, num_samples = 1, replacement = true, generator = g).item<int64>()

        context <- Array.append context[1..] [|ix|]
        
        if int ix = 0 then
            shouldContinue <- false
        else
            ix |> int |> itoc |> name.Append |> ignore

    name.ToString()

[1..5] |> Seq.iter (ignore >> generateNameByNetwork >> print)
laya
ten
lucilla
herleigh
micharmoni
layers
|> Seq.indexed
|> Seq.choose (fun (i, layer) ->
    match layer with
    | :? Tanh as layer ->
        let x = (layer :> ILayer).Out.detach()
        let histogram = torch.histc(x)
        let saturate = x.abs().greater(torch.tensor(0.97)).float().mean() * scalar(100)
        printfn $"mean: %.4f{x.mean().data()[0]} \t std: %.4f{x.std().data()[0]} \t satuate: %.2f{saturate.data()[0]}%%"
        Chart.Line([1..100], histogram.data(), Name = $"Layer {i} Tanh") |> Some
    | _ -> None
)
|> Chart.combine
|> Chart.withSize(900, 400)
mean: -0.0000 	 std: 0.7526 	 saturate: 20.00%
mean: -0.0038 	 std: 0.7244 	 saturate: 13.00%
mean: 0.0298 	 std: 0.7230 	 saturate: 17.00%
mean: 0.0401 	 std: 0.6991 	 saturate: 11.00%
mean: -0.0320 	 std: 0.6640 	 saturate: 4.00%
parameters
|> Seq.indexed
|> Seq.choose (fun (i, p) ->
    let grad = p.grad()
    if p.ndim = 2 then
        printfn $"weight: {p.shape.ToDisplayString()} \t mean: {grad.mean().data<float32>()[0]} \t std: {grad.std().data<float32>()[0]} \t grad/data ratio: {(grad.std() / p.std()).data<float32>()[0]}"
        let histogram = torch.histc(grad)
        Chart.Line([1..100], histogram.data<float32>(), Name = $"{i} {p.shape.ToDisplayString()}") |> Some
    else
        None
)
|> Chart.combine
|> Chart.withSize(900, 400)
weight: [ 27, 10 ] 	 mean: -6.760712E-10 	 std: 0.024949072 	 grad/data ratio: 0.023801908
weight: [ 30, 100 ] 	 mean: 0.00072093663 	 std: 0.018865554 	 grad/data ratio: 0.056021456
weight: [ 100, 100 ] 	 mean: -0.00020415313 	 std: 0.012050368 	 grad/data ratio: 0.06051427
weight: [ 100, 100 ] 	 mean: 5.1357336E-05 	 std: 0.012066372 	 grad/data ratio: 0.061524145
weight: [ 100, 100 ] 	 mean: 4.1329444E-05 	 std: 0.012293761 	 grad/data ratio: 0.063286
weight: [ 100, 100 ] 	 mean: 4.8571485E-05 	 std: 0.010817464 	 grad/data ratio: 0.056862194
weight: [ 100, 27 ] 	 mean: -0.00031279188 	 std: 0.018326223 	 grad/data ratio: 0.077995434
[
    // guide line at log10(ratio) = -3: rule of thumb, updates should be roughly 1e-3 of the parameter scale
    Chart.Line([1..upgradeToData[0].Count], [for _ in 1..upgradeToData[0].Count -> -3], Name = "guide line")
    for i, p in List.indexed parameters do
        if p.ndim = 2 then
            Chart.Line([1..upgradeToData[i].Count], upgradeToData[i], Name = $"parameter{i}")
]
|> Chart.combine
|> Chart.withSize(900, 400)