Update mingpt plugin to include recent chat history in the seed
author unc0rr
Tue, 15 Jun 2021 20:46:28 +0200
changeset 15815 96443d9b48c9
parent 15814 191e51179d1b
child 15816 7598960819a1
Update mingpt plugin to include recent chat history in the seed
tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs
--- a/tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs	Tue Jun 15 20:45:46 2021 +0200
+++ b/tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs	Tue Jun 15 20:46:28 2021 +0200
@@ -256,6 +256,16 @@
                 )
                 .await?;
 
+            sub_channel
+                .queue_bind(
+                    queue.name().as_str(),
+                    "irc",
+                    "msg.hedgewars",
+                    QueueBindOptions::default(),
+                    FieldTable::default(),
+                )
+                .await?;
+
             let mut subscriber = sub_channel
                 .basic_consume(
                     queue.name().as_str(),
@@ -267,22 +277,36 @@
 
             vs.load(args[2].as_str())?;
 
+            let mut buffer = Vec::new();
+
             while let Some(amqp_message) = subscriber.next().await {
                 let (_, delivery) = amqp_message.expect("error in consumer");
                 delivery.ack(BasicAckOptions::default()).await?;
 
-                let chat_message = String::from_utf8(delivery.data)?;
-                if let Some((_who, seed)) = chat_message.split_once('\n') {
-                    let input_sample = &format!("\n{}", seed);
+                if delivery.routing_key.as_str() == "msg.hedgewars" {
+                    let chat_message = String::from_utf8_lossy(&delivery.data);
+                    if let Some((_who, message)) = chat_message.split_once('\n') {
+                        buffer.push('\n');
+                        buffer.extend(message.chars());
+                        if buffer.len() >= BLOCK_SIZE as usize {
+                            let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize);
+                        }
+                    }
+                } else {
+                    let chat_message = String::from_utf8_lossy(&delivery.data);
+                    let seed = chat_message.split_once('\n').map(|(_, s)| s).unwrap_or("");
+                    buffer.push('\n');
+                    buffer.extend(seed.chars());
+
+                    if buffer.len() >= BLOCK_SIZE as usize {
+                        let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize);
+                    }
+
                     let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device));
-                    for (idx, c) in input_sample.chars().rev().enumerate() {
-                        let idx = idx as i64;
-                        if idx >= BLOCK_SIZE {
-                            break;
-                        }
+                    for (idx, c) in buffer.iter().rev().enumerate() {
                         let _filled = input
-                            .i((0, BLOCK_SIZE - 1 - idx))
-                            .fill_(data.char_to_label(c).unwrap_or(0) as i64);
+                            .i((0, BLOCK_SIZE - 1 - idx as i64))
+                            .fill_(data.char_to_label(*c).unwrap_or(0) as i64);
                     }
 
                     let proceeded_message = &sample(&data, &gpt, input);