author | unC0Rr |
Fri, 03 Feb 2023 14:44:33 +0100 | |
branch | transitional_engine |
changeset 15916 | e82de0410da5 |
parent 15793 | 96443d9b48c9 |
permissions | -rw-r--r-- |
15791
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
1 |
/* Adapted from the tch minGPT example, which was written for the tinyshakespeare dataset that can be downloaded at:
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
2 |
https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
3 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
4 |
This is mostly a Rust port of https://github.com/karpathy/minGPT
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
5 |
*/ |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
6 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
7 |
extern crate tch; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
8 |
use anyhow::{bail, Result as AHResult}; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
9 |
use std::{io, io::Write}; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
10 |
use tch::data::TextData; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
11 |
use tch::nn::{ModuleT, OptimizerConfig}; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
12 |
use tch::{nn, Device, IndexOp, Kind, Tensor}; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
13 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
14 |
use futures::prelude::*; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
15 |
use lapin::{options::*, types::FieldTable, BasicProperties, Connection, ConnectionProperties}; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
16 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
17 |
use tokio_amqp::*; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
18 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
19 |
/// Learning rate passed to the AdamW optimizer.
const LEARNING_RATE: f64 = 3e-4;
/// Context length in tokens: sizes the causal mask and positional embedding,
/// and bounds the context window used while sampling.
const BLOCK_SIZE: i64 = 128;
/// Number of sequences per training batch.
const BATCH_SIZE: i64 = 64;
/// Number of training epochs.
const EPOCHS: i64 = 100;
/// Number of characters produced by one call to `sample`.
const SAMPLING_LEN: i64 = 512;
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
24 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
25 |
/// Hyper-parameters of the GPT model assembled by `gpt`.
#[derive(Clone, Copy, Debug)]
struct Config {
    /// Number of distinct tokens; sizes the token embedding and the output head.
    vocab_size: i64,
    /// Width of the embeddings and of the residual stream in every block.
    n_embd: i64,
    /// Number of attention heads; each head gets `n_embd / n_head` channels.
    n_head: i64,
    /// Number of transformer blocks stacked in the model.
    n_layer: i64,
    /// Maximum sequence length (size of the causal mask and positional embedding).
    block_size: i64,
    /// Dropout probability applied to the attention weights.
    attn_pdrop: f64,
    /// Dropout probability applied to the attention projection and the MLP output.
    resid_pdrop: f64,
    /// Dropout probability applied to the summed token and position embeddings.
    embd_pdrop: f64,
}
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
36 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
37 |
// Weight decay only applies to the weight matrices of the linear layers;
// every other parameter is registered in the no-decay group.

/// Optimizer parameter group that is exempt from weight decay.
const NO_WEIGHT_DECAY_GROUP: usize = 0;
/// Optimizer parameter group that receives weight decay (linear-layer weights).
const WEIGHT_DECAY_GROUP: usize = 1;
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
40 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
41 |
// Custom linear layer so that different groups can be used for weight |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
42 |
// and biases. |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
43 |
#[derive(Debug)] |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
44 |
struct Linear { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
45 |
pub ws: Tensor, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
46 |
pub bs: Tensor, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
47 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
48 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
49 |
impl nn::Module for Linear { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
50 |
fn forward(&self, xs: &Tensor) -> Tensor { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
51 |
xs.matmul(&self.ws.tr()) + &self.bs |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
52 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
53 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
54 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
55 |
fn linear(vs: nn::Path, in_dim: i64, out_dim: i64) -> Linear { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
56 |
let wd = vs.set_group(WEIGHT_DECAY_GROUP); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
57 |
let no_wd = vs.set_group(NO_WEIGHT_DECAY_GROUP); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
58 |
Linear { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
59 |
ws: wd.randn("weight", &[out_dim, in_dim], 0.0, 0.02), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
60 |
bs: no_wd.zeros("bias", &[out_dim]), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
61 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
62 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
63 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
64 |
fn linear_no_bias(vs: nn::Path, in_dim: i64, out_dim: i64) -> Linear { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
65 |
let wd = vs.set_group(WEIGHT_DECAY_GROUP); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
66 |
let no_wd = vs.set_group(NO_WEIGHT_DECAY_GROUP); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
67 |
Linear { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
68 |
ws: wd.randn("weight", &[out_dim, in_dim], 0.0, 0.02), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
69 |
bs: no_wd.zeros_no_train("bias", &[out_dim]), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
70 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
71 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
72 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
73 |
fn causal_self_attention(p: &nn::Path, cfg: Config) -> impl ModuleT { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
74 |
let key = linear(p / "key", cfg.n_embd, cfg.n_embd); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
75 |
let query = linear(p / "query", cfg.n_embd, cfg.n_embd); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
76 |
let value = linear(p / "value", cfg.n_embd, cfg.n_embd); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
77 |
let proj = linear(p / "proj", cfg.n_embd, cfg.n_embd); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
78 |
let mask_init = |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
79 |
Tensor::ones(&[cfg.block_size, cfg.block_size], (Kind::Float, p.device())).tril(0); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
80 |
let mask_init = mask_init.view([1, 1, cfg.block_size, cfg.block_size]); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
81 |
// let mask = p.var_copy("mask", &mask_init); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
82 |
let mask = mask_init; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
83 |
nn::func_t(move |xs, train| { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
84 |
let (sz_b, sz_t, sz_c) = xs.size3().unwrap(); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
85 |
let sizes = [sz_b, sz_t, cfg.n_head, sz_c / cfg.n_head]; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
86 |
let k = xs.apply(&key).view(sizes).transpose(1, 2); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
87 |
let q = xs.apply(&query).view(sizes).transpose(1, 2); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
88 |
let v = xs.apply(&value).view(sizes).transpose(1, 2); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
89 |
let att = q.matmul(&k.transpose(-2, -1)) * (1.0 / f64::sqrt(sizes[3] as f64)); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
90 |
let att = att.masked_fill( |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
91 |
&mask.i((.., .., ..sz_t, ..sz_t)).eq(0.), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
92 |
std::f64::NEG_INFINITY, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
93 |
); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
94 |
let att = att.softmax(-1, Kind::Float).dropout(cfg.attn_pdrop, train); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
95 |
let ys = att |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
96 |
.matmul(&v) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
97 |
.transpose(1, 2) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
98 |
.contiguous() |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
99 |
.view([sz_b, sz_t, sz_c]); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
100 |
ys.apply(&proj).dropout(cfg.resid_pdrop, train) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
101 |
}) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
102 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
103 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
104 |
fn block(p: &nn::Path, cfg: Config) -> impl ModuleT { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
105 |
let ln1 = nn::layer_norm(p / "ln1", vec![cfg.n_embd], Default::default()); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
106 |
let ln2 = nn::layer_norm(p / "ln2", vec![cfg.n_embd], Default::default()); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
107 |
let attn = causal_self_attention(p, cfg); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
108 |
let lin1 = linear(p / "lin1", cfg.n_embd, 4 * cfg.n_embd); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
109 |
let lin2 = linear(p / "lin2", 4 * cfg.n_embd, cfg.n_embd); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
110 |
nn::func_t(move |xs, train| { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
111 |
let xs = xs + xs.apply(&ln1).apply_t(&attn, train); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
112 |
let ys = xs |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
113 |
.apply(&ln2) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
114 |
.apply(&lin1) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
115 |
.gelu() |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
116 |
.apply(&lin2) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
117 |
.dropout(cfg.resid_pdrop, train); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
118 |
xs + ys |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
119 |
}) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
120 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
121 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
122 |
fn gpt(p: &nn::Path, cfg: Config) -> impl ModuleT { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
123 |
let p = &p.set_group(NO_WEIGHT_DECAY_GROUP); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
124 |
let tok_emb = nn::embedding( |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
125 |
p / "tok_emb", |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
126 |
cfg.vocab_size, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
127 |
cfg.n_embd, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
128 |
Default::default(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
129 |
); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
130 |
let pos_emb = p.zeros("pos_emb", &[1, cfg.block_size, cfg.n_embd]); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
131 |
let ln_f = nn::layer_norm(p / "ln_f", vec![cfg.n_embd], Default::default()); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
132 |
let head = linear_no_bias(p / "head", cfg.n_embd, cfg.vocab_size); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
133 |
let mut blocks = nn::seq_t(); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
134 |
for block_idx in 0..cfg.n_layer { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
135 |
blocks = blocks.add(block(&(p / block_idx), cfg)); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
136 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
137 |
nn::func_t(move |xs, train| { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
138 |
let (_sz_b, sz_t) = xs.size2().unwrap(); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
139 |
let tok_emb = xs.apply(&tok_emb); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
140 |
let pos_emb = pos_emb.i((.., ..sz_t, ..)); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
141 |
(tok_emb + pos_emb) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
142 |
.dropout(cfg.embd_pdrop, train) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
143 |
.apply_t(&blocks, train) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
144 |
.apply(&ln_f) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
145 |
.apply(&head) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
146 |
}) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
147 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
148 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
149 |
/// Generates some sample string using the GPT model. |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
150 |
fn sample(data: &TextData, gpt: &impl ModuleT, input: Tensor) -> String { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
151 |
let mut input = input; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
152 |
let mut result = String::new(); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
153 |
for _index in 0..SAMPLING_LEN { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
154 |
let logits = input.apply_t(gpt, false).i((0, -1, ..)); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
155 |
let sampled_y = logits.softmax(-1, Kind::Float).multinomial(1, true); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
156 |
let last_label = i64::from(&sampled_y); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
157 |
result.push(data.label_to_char(last_label)); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
158 |
input = Tensor::cat(&[input, sampled_y.view([1, 1])], 1).narrow(1, 1, BLOCK_SIZE); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
159 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
160 |
result |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
161 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
162 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
163 |
#[tokio::main] |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
164 |
async fn main() -> AHResult<()> { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
165 |
let device = Device::cuda_if_available(); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
166 |
let mut vs = nn::VarStore::new(device); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
167 |
let data = TextData::new("10.log")?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
168 |
let labels = data.labels(); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
169 |
println!("Dataset loaded, {} labels.", labels); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
170 |
let cfg = Config { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
171 |
vocab_size: labels, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
172 |
n_embd: 384, // was 512 |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
173 |
n_head: 8, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
174 |
n_layer: 8, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
175 |
block_size: BLOCK_SIZE, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
176 |
attn_pdrop: 0.1, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
177 |
resid_pdrop: 0.1, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
178 |
embd_pdrop: 0.1, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
179 |
}; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
180 |
let gpt = gpt(&(&vs.root() / "gpt"), cfg); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
181 |
let args: Vec<_> = std::env::args().collect(); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
182 |
if args.len() < 2 { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
183 |
bail!("usage: main (train|predict weights.ot seqstart)") |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
184 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
185 |
match args[1].as_str() { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
186 |
"train" => { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
187 |
let mut opt = nn::AdamW::default().build(&vs, LEARNING_RATE)?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
188 |
opt.set_weight_decay_group(NO_WEIGHT_DECAY_GROUP, 0.0); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
189 |
opt.set_weight_decay_group(WEIGHT_DECAY_GROUP, 0.1); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
190 |
let mut idx = 0; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
191 |
vs.load("384.ot")?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
192 |
for epoch in 1..(1 + EPOCHS) { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
193 |
let mut sum_loss = 0.; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
194 |
let mut cnt_loss = 0.; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
195 |
for batch in data.iter_shuffle(BLOCK_SIZE + 1, BATCH_SIZE) { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
196 |
let xs = batch |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
197 |
.narrow(1, 0, BLOCK_SIZE) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
198 |
.to_kind(Kind::Int64) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
199 |
.to_device(device); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
200 |
let ys = batch |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
201 |
.narrow(1, 1, BLOCK_SIZE) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
202 |
.to_kind(Kind::Int64) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
203 |
.to_device(device); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
204 |
let logits = xs.apply_t(&gpt, true); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
205 |
let loss = logits |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
206 |
.view([BATCH_SIZE * BLOCK_SIZE, labels]) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
207 |
.cross_entropy_for_logits(&ys.view([BATCH_SIZE * BLOCK_SIZE])); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
208 |
opt.backward_step_clip(&loss, 0.5); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
209 |
sum_loss += f64::from(loss); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
210 |
cnt_loss += 1.0; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
211 |
idx += 1; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
212 |
if idx % 10 == 0 { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
213 |
print!("{}", '.'); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
214 |
io::stdout().flush()?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
215 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
216 |
if idx % 1000 == 0 { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
217 |
println!("Epoch: {} loss: {:5.3}", epoch, sum_loss / cnt_loss); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
218 |
let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device)); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
219 |
println!("Sample: {}", sample(&data, &gpt, input)); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
220 |
if let Err(err) = vs.save(format!("gpt{:08}.ot", idx)) { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
221 |
println!("error while saving {}", err); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
222 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
223 |
sum_loss = 0.; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
224 |
cnt_loss = 0.; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
225 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
226 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
227 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
228 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
229 |
"predict" => { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
230 |
let amqp_url = std::env::var("AMQP_URL").expect("expected AMQP_URL env variabe"); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
231 |
let conn = Connection::connect(&amqp_url, ConnectionProperties::default().with_tokio()) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
232 |
.await?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
233 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
234 |
let pub_channel = conn.create_channel().await?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
235 |
let sub_channel = conn.create_channel().await?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
236 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
237 |
let queue = sub_channel |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
238 |
.queue_declare( |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
239 |
&"", |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
240 |
QueueDeclareOptions { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
241 |
exclusive: true, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
242 |
auto_delete: true, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
243 |
..QueueDeclareOptions::default() |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
244 |
}, |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
245 |
FieldTable::default(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
246 |
) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
247 |
.await?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
248 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
249 |
sub_channel |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
250 |
.queue_bind( |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
251 |
queue.name().as_str(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
252 |
"irc", |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
253 |
"cmd.say.hedgewars", |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
254 |
QueueBindOptions::default(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
255 |
FieldTable::default(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
256 |
) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
257 |
.await?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
258 |
|
15793
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
259 |
sub_channel |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
260 |
.queue_bind( |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
261 |
queue.name().as_str(), |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
262 |
"irc", |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
263 |
"msg.hedgewars", |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
264 |
QueueBindOptions::default(), |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
265 |
FieldTable::default(), |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
266 |
) |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
267 |
.await?; |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
268 |
|
15791
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
269 |
let mut subscriber = sub_channel |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
270 |
.basic_consume( |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
271 |
queue.name().as_str(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
272 |
&"", |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
273 |
BasicConsumeOptions::default(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
274 |
FieldTable::default(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
275 |
) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
276 |
.await?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
277 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
278 |
vs.load(args[2].as_str())?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
279 |
|
15793
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
280 |
let mut buffer = Vec::new(); |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
281 |
|
15791
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
282 |
while let Some(amqp_message) = subscriber.next().await { |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
283 |
let (_, delivery) = amqp_message.expect("error in consumer"); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
284 |
delivery.ack(BasicAckOptions::default()).await?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
285 |
|
15793
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
286 |
if delivery.routing_key.as_str() == "msg.hedgewars" { |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
287 |
let chat_message = String::from_utf8_lossy(&delivery.data); |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
288 |
if let Some((_who, message)) = chat_message.split_once('\n') { |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
289 |
buffer.push('\n'); |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
290 |
buffer.extend(message.chars()); |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
291 |
if buffer.len() >= BLOCK_SIZE as usize { |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
292 |
let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize); |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
293 |
} |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
294 |
} |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
295 |
} else { |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
296 |
let chat_message = String::from_utf8_lossy(&delivery.data); |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
297 |
let seed = chat_message.split_once('\n').map(|(_, s)| s).unwrap_or(""); |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
298 |
buffer.push('\n'); |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
299 |
buffer.extend(seed.chars()); |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
300 |
|
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
301 |
if buffer.len() >= BLOCK_SIZE as usize { |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
302 |
let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize); |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
303 |
} |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
304 |
|
15791
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
305 |
let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device)); |
15793
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
306 |
for (idx, c) in buffer.iter().rev().enumerate() { |
15791
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
307 |
let _filled = input |
15793
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
308 |
.i((0, BLOCK_SIZE - 1 - idx as i64)) |
96443d9b48c9
Update mingpt plugin to include recent chat history in the seed
unc0rr
parents:
15791
diff
changeset
|
309 |
.fill_(data.char_to_label(*c).unwrap_or(0) as i64); |
15791
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
310 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
311 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
312 |
let proceeded_message = &sample(&data, &gpt, input); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
313 |
let final_message = proceeded_message |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
314 |
.split_once('\n') |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
315 |
.map(|(m, _)| m) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
316 |
.unwrap_or(proceeded_message); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
317 |
let final_message = &format!("{}{}", seed, final_message); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
318 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
319 |
println!("{} --> {}", seed, proceeded_message); |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
320 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
321 |
pub_channel |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
322 |
.basic_publish( |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
323 |
"irc", |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
324 |
"say.hedgewars", |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
325 |
BasicPublishOptions::default(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
326 |
final_message.as_bytes().to_vec(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
327 |
BasicProperties::default(), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
328 |
) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
329 |
.await?; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
330 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
331 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
332 |
} |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
333 |
_ => bail!("usage: main (train|predict weights.ot)"), |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
334 |
}; |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
335 |
|
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
336 |
Ok(()) |
2528e3508bf4
Add mingpt plugin which is mingpt example from tch repo plus amqp wrapper by me
unc0rr
parents:
diff
changeset
|
337 |
} |