tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs
author unC0Rr
Fri, 03 Feb 2023 14:44:33 +0100
branchtransitional_engine
changeset 15916 e82de0410da5
parent 15793 96443d9b48c9
permissions -rw-r--r--
Rework how rules are defined, add transformations for tiles
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
15791
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
     1
/* Adapted from the tch minGPT example, which trains on the tinyshakespeare dataset (this plugin loads `10.log` instead); the dataset can be downloaded at:
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
     2
   https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
     3
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
     4
   This is mostly a rust port of https://github.com/karpathy/minGPT
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
     5
*/
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
     6
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
     7
extern crate tch;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
     8
use anyhow::{bail, Result as AHResult};
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
     9
use std::{io, io::Write};
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    10
use tch::data::TextData;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    11
use tch::nn::{ModuleT, OptimizerConfig};
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    12
use tch::{nn, Device, IndexOp, Kind, Tensor};
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    13
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    14
use futures::prelude::*;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    15
use lapin::{options::*, types::FieldTable, BasicProperties, Connection, ConnectionProperties};
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    16
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    17
use tokio_amqp::*;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    18
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    19
/// Learning rate used when training the model.
const LEARNING_RATE: f64 = 3e-4;
/// Context length in tokens: the model's `block_size` and the window `sample` keeps.
const BLOCK_SIZE: i64 = 128;
/// Batch size for training.
const BATCH_SIZE: i64 = 64;
/// Number of training epochs.
const EPOCHS: i64 = 100;
/// Number of characters produced by each call to `sample`.
const SAMPLING_LEN: i64 = 512;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    24
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    25
/// Hyper-parameters of the GPT model built by `gpt`.
#[derive(Clone, Copy, Debug)]
struct Config {
    /// Number of distinct token labels (rows of the embedding table, width of the output head).
    vocab_size: i64,
    /// Embedding / hidden width.
    n_embd: i64,
    /// Attention heads per block; `n_embd` is split evenly between them.
    n_head: i64,
    /// Number of transformer blocks stacked in the model.
    n_layer: i64,
    /// Maximum sequence length (size of the positional embedding and the causal mask).
    block_size: i64,
    /// Dropout probability applied to the attention weights.
    attn_pdrop: f64,
    /// Dropout probability applied to the residual branches.
    resid_pdrop: f64,
    /// Dropout probability applied to the summed token + position embeddings.
    embd_pdrop: f64,
}
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    36
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    37
// Optimizer parameter groups. Weight decay is intended only for the weight matrices
// of the linear layers; biases, embeddings and layer-norm parameters are registered
// in the group without decay.

/// Group index for parameters that must not be weight-decayed.
const NO_WEIGHT_DECAY_GROUP: usize = 0;
/// Group index for the linear-layer weight matrices (weight-decayed).
const WEIGHT_DECAY_GROUP: usize = 1;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    40
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    41
// Custom linear layer so that different groups can be used for weight
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    42
// and biases.
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    43
#[derive(Debug)]
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    44
struct Linear {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    45
    pub ws: Tensor,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    46
    pub bs: Tensor,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    47
}
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    48
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    49
impl nn::Module for Linear {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    50
    fn forward(&self, xs: &Tensor) -> Tensor {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    51
        xs.matmul(&self.ws.tr()) + &self.bs
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    52
    }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    53
}
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    54
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    55
fn linear(vs: nn::Path, in_dim: i64, out_dim: i64) -> Linear {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    56
    let wd = vs.set_group(WEIGHT_DECAY_GROUP);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    57
    let no_wd = vs.set_group(NO_WEIGHT_DECAY_GROUP);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    58
    Linear {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    59
        ws: wd.randn("weight", &[out_dim, in_dim], 0.0, 0.02),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    60
        bs: no_wd.zeros("bias", &[out_dim]),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    61
    }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    62
}
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    63
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    64
fn linear_no_bias(vs: nn::Path, in_dim: i64, out_dim: i64) -> Linear {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    65
    let wd = vs.set_group(WEIGHT_DECAY_GROUP);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    66
    let no_wd = vs.set_group(NO_WEIGHT_DECAY_GROUP);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    67
    Linear {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    68
        ws: wd.randn("weight", &[out_dim, in_dim], 0.0, 0.02),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    69
        bs: no_wd.zeros_no_train("bias", &[out_dim]),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    70
    }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    71
}
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    72
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    73
fn causal_self_attention(p: &nn::Path, cfg: Config) -> impl ModuleT {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    74
    let key = linear(p / "key", cfg.n_embd, cfg.n_embd);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    75
    let query = linear(p / "query", cfg.n_embd, cfg.n_embd);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    76
    let value = linear(p / "value", cfg.n_embd, cfg.n_embd);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    77
    let proj = linear(p / "proj", cfg.n_embd, cfg.n_embd);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    78
    let mask_init =
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    79
        Tensor::ones(&[cfg.block_size, cfg.block_size], (Kind::Float, p.device())).tril(0);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    80
    let mask_init = mask_init.view([1, 1, cfg.block_size, cfg.block_size]);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    81
    // let mask = p.var_copy("mask", &mask_init);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    82
    let mask = mask_init;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    83
    nn::func_t(move |xs, train| {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    84
        let (sz_b, sz_t, sz_c) = xs.size3().unwrap();
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    85
        let sizes = [sz_b, sz_t, cfg.n_head, sz_c / cfg.n_head];
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    86
        let k = xs.apply(&key).view(sizes).transpose(1, 2);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    87
        let q = xs.apply(&query).view(sizes).transpose(1, 2);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    88
        let v = xs.apply(&value).view(sizes).transpose(1, 2);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    89
        let att = q.matmul(&k.transpose(-2, -1)) * (1.0 / f64::sqrt(sizes[3] as f64));
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    90
        let att = att.masked_fill(
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    91
            &mask.i((.., .., ..sz_t, ..sz_t)).eq(0.),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    92
            std::f64::NEG_INFINITY,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    93
        );
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    94
        let att = att.softmax(-1, Kind::Float).dropout(cfg.attn_pdrop, train);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    95
        let ys = att
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    96
            .matmul(&v)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    97
            .transpose(1, 2)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    98
            .contiguous()
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
    99
            .view([sz_b, sz_t, sz_c]);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   100
        ys.apply(&proj).dropout(cfg.resid_pdrop, train)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   101
    })
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   102
}
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   103
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   104
fn block(p: &nn::Path, cfg: Config) -> impl ModuleT {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   105
    let ln1 = nn::layer_norm(p / "ln1", vec![cfg.n_embd], Default::default());
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   106
    let ln2 = nn::layer_norm(p / "ln2", vec![cfg.n_embd], Default::default());
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   107
    let attn = causal_self_attention(p, cfg);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   108
    let lin1 = linear(p / "lin1", cfg.n_embd, 4 * cfg.n_embd);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   109
    let lin2 = linear(p / "lin2", 4 * cfg.n_embd, cfg.n_embd);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   110
    nn::func_t(move |xs, train| {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   111
        let xs = xs + xs.apply(&ln1).apply_t(&attn, train);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   112
        let ys = xs
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   113
            .apply(&ln2)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   114
            .apply(&lin1)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   115
            .gelu()
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   116
            .apply(&lin2)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   117
            .dropout(cfg.resid_pdrop, train);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   118
        xs + ys
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   119
    })
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   120
}
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   121
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   122
fn gpt(p: &nn::Path, cfg: Config) -> impl ModuleT {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   123
    let p = &p.set_group(NO_WEIGHT_DECAY_GROUP);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   124
    let tok_emb = nn::embedding(
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   125
        p / "tok_emb",
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   126
        cfg.vocab_size,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   127
        cfg.n_embd,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   128
        Default::default(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   129
    );
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   130
    let pos_emb = p.zeros("pos_emb", &[1, cfg.block_size, cfg.n_embd]);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   131
    let ln_f = nn::layer_norm(p / "ln_f", vec![cfg.n_embd], Default::default());
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   132
    let head = linear_no_bias(p / "head", cfg.n_embd, cfg.vocab_size);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   133
    let mut blocks = nn::seq_t();
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   134
    for block_idx in 0..cfg.n_layer {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   135
        blocks = blocks.add(block(&(p / block_idx), cfg));
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   136
    }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   137
    nn::func_t(move |xs, train| {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   138
        let (_sz_b, sz_t) = xs.size2().unwrap();
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   139
        let tok_emb = xs.apply(&tok_emb);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   140
        let pos_emb = pos_emb.i((.., ..sz_t, ..));
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   141
        (tok_emb + pos_emb)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   142
            .dropout(cfg.embd_pdrop, train)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   143
            .apply_t(&blocks, train)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   144
            .apply(&ln_f)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   145
            .apply(&head)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   146
    })
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   147
}
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   148
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   149
/// Generates some sample string using the GPT model.
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   150
fn sample(data: &TextData, gpt: &impl ModuleT, input: Tensor) -> String {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   151
    let mut input = input;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   152
    let mut result = String::new();
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   153
    for _index in 0..SAMPLING_LEN {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   154
        let logits = input.apply_t(gpt, false).i((0, -1, ..));
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   155
        let sampled_y = logits.softmax(-1, Kind::Float).multinomial(1, true);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   156
        let last_label = i64::from(&sampled_y);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   157
        result.push(data.label_to_char(last_label));
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   158
        input = Tensor::cat(&[input, sampled_y.view([1, 1])], 1).narrow(1, 1, BLOCK_SIZE);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   159
    }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   160
    result
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   161
}
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   162
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   163
#[tokio::main]
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   164
async fn main() -> AHResult<()> {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   165
    let device = Device::cuda_if_available();
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   166
    let mut vs = nn::VarStore::new(device);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   167
    let data = TextData::new("10.log")?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   168
    let labels = data.labels();
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   169
    println!("Dataset loaded, {} labels.", labels);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   170
    let cfg = Config {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   171
        vocab_size: labels,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   172
        n_embd: 384, // was 512
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   173
        n_head: 8,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   174
        n_layer: 8,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   175
        block_size: BLOCK_SIZE,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   176
        attn_pdrop: 0.1,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   177
        resid_pdrop: 0.1,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   178
        embd_pdrop: 0.1,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   179
    };
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   180
    let gpt = gpt(&(&vs.root() / "gpt"), cfg);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   181
    let args: Vec<_> = std::env::args().collect();
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   182
    if args.len() < 2 {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   183
        bail!("usage: main (train|predict weights.ot seqstart)")
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   184
    }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   185
    match args[1].as_str() {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   186
        "train" => {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   187
            let mut opt = nn::AdamW::default().build(&vs, LEARNING_RATE)?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   188
            opt.set_weight_decay_group(NO_WEIGHT_DECAY_GROUP, 0.0);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   189
            opt.set_weight_decay_group(WEIGHT_DECAY_GROUP, 0.1);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   190
            let mut idx = 0;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   191
            vs.load("384.ot")?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   192
            for epoch in 1..(1 + EPOCHS) {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   193
                let mut sum_loss = 0.;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   194
                let mut cnt_loss = 0.;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   195
                for batch in data.iter_shuffle(BLOCK_SIZE + 1, BATCH_SIZE) {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   196
                    let xs = batch
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   197
                        .narrow(1, 0, BLOCK_SIZE)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   198
                        .to_kind(Kind::Int64)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   199
                        .to_device(device);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   200
                    let ys = batch
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   201
                        .narrow(1, 1, BLOCK_SIZE)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   202
                        .to_kind(Kind::Int64)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   203
                        .to_device(device);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   204
                    let logits = xs.apply_t(&gpt, true);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   205
                    let loss = logits
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   206
                        .view([BATCH_SIZE * BLOCK_SIZE, labels])
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   207
                        .cross_entropy_for_logits(&ys.view([BATCH_SIZE * BLOCK_SIZE]));
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   208
                    opt.backward_step_clip(&loss, 0.5);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   209
                    sum_loss += f64::from(loss);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   210
                    cnt_loss += 1.0;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   211
                    idx += 1;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   212
                    if idx % 10 == 0 {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   213
                        print!("{}", '.');
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   214
                        io::stdout().flush()?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   215
                    }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   216
                    if idx % 1000 == 0 {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   217
                        println!("Epoch: {}   loss: {:5.3}", epoch, sum_loss / cnt_loss);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   218
                        let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device));
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   219
                        println!("Sample: {}", sample(&data, &gpt, input));
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   220
                        if let Err(err) = vs.save(format!("gpt{:08}.ot", idx)) {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   221
                            println!("error while saving {}", err);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   222
                        }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   223
                        sum_loss = 0.;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   224
                        cnt_loss = 0.;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   225
                    }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   226
                }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   227
            }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   228
        }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   229
        "predict" => {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   230
            let amqp_url = std::env::var("AMQP_URL").expect("expected AMQP_URL env variabe");
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   231
            let conn = Connection::connect(&amqp_url, ConnectionProperties::default().with_tokio())
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   232
                .await?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   233
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   234
            let pub_channel = conn.create_channel().await?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   235
            let sub_channel = conn.create_channel().await?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   236
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   237
            let queue = sub_channel
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   238
                .queue_declare(
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   239
                    &"",
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   240
                    QueueDeclareOptions {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   241
                        exclusive: true,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   242
                        auto_delete: true,
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   243
                        ..QueueDeclareOptions::default()
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   244
                    },
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   245
                    FieldTable::default(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   246
                )
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   247
                .await?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   248
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   249
            sub_channel
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   250
                .queue_bind(
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   251
                    queue.name().as_str(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   252
                    "irc",
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   253
                    "cmd.say.hedgewars",
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   254
                    QueueBindOptions::default(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   255
                    FieldTable::default(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   256
                )
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   257
                .await?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   258
15793
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   259
            sub_channel
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   260
                .queue_bind(
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   261
                    queue.name().as_str(),
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   262
                    "irc",
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   263
                    "msg.hedgewars",
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   264
                    QueueBindOptions::default(),
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   265
                    FieldTable::default(),
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   266
                )
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   267
                .await?;
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   268
15791
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   269
            let mut subscriber = sub_channel
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   270
                .basic_consume(
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   271
                    queue.name().as_str(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   272
                    &"",
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   273
                    BasicConsumeOptions::default(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   274
                    FieldTable::default(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   275
                )
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   276
                .await?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   277
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   278
            vs.load(args[2].as_str())?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   279
15793
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   280
            let mut buffer = Vec::new();
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   281
15791
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   282
            while let Some(amqp_message) = subscriber.next().await {
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   283
                let (_, delivery) = amqp_message.expect("error in consumer");
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   284
                delivery.ack(BasicAckOptions::default()).await?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   285
15793
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   286
                if delivery.routing_key.as_str() == "msg.hedgewars" {
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   287
                    let chat_message = String::from_utf8_lossy(&delivery.data);
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   288
                    if let Some((_who, message)) = chat_message.split_once('\n') {
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   289
                        buffer.push('\n');
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   290
                        buffer.extend(message.chars());
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   291
                        if buffer.len() >= BLOCK_SIZE as usize {
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   292
                            let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize);
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   293
                        }
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   294
                    }
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   295
                } else {
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   296
                    let chat_message = String::from_utf8_lossy(&delivery.data);
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   297
                    let seed = chat_message.split_once('\n').map(|(_, s)| s).unwrap_or("");
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   298
                    buffer.push('\n');
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   299
                    buffer.extend(seed.chars());
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   300
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   301
                    if buffer.len() >= BLOCK_SIZE as usize {
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   302
                        let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize);
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   303
                    }
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   304
15791
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   305
                    let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device));
15793
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   306
                    for (idx, c) in buffer.iter().rev().enumerate() {
15791
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   307
                        let _filled = input
15793
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   308
                            .i((0, BLOCK_SIZE - 1 - idx as i64))
96443d9b48c9 Update mingpt plugin to include recent chat history in the seed
unc0rr
parents: 15791
diff changeset
   309
                            .fill_(data.char_to_label(*c).unwrap_or(0) as i64);
15791
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   310
                    }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   311
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   312
                    let proceeded_message = &sample(&data, &gpt, input);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   313
                    let final_message = proceeded_message
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   314
                        .split_once('\n')
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   315
                        .map(|(m, _)| m)
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   316
                        .unwrap_or(proceeded_message);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   317
                    let final_message = &format!("{}{}", seed, final_message);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   318
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   319
                    println!("{} --> {}", seed, proceeded_message);
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   320
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   321
                    pub_channel
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   322
                        .basic_publish(
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   323
                            "irc",
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   324
                            "say.hedgewars",
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   325
                            BasicPublishOptions::default(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   326
                            final_message.as_bytes().to_vec(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   327
                            BasicProperties::default(),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   328
                        )
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   329
                        .await?;
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   330
                }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   331
            }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   332
        }
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   333
        _ => bail!("usage: main (train|predict weights.ot)"),
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   334
    };
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   335
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   336
    Ok(())
2528e3508bf4 Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff changeset
   337
}