--- a/tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs Tue Jun 15 20:45:46 2021 +0200
+++ b/tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs Tue Jun 15 20:46:28 2021 +0200
@@ -256,6 +256,16 @@
)
.await?;
+ sub_channel
+ .queue_bind(
+ queue.name().as_str(),
+ "irc",
+ "msg.hedgewars",
+ QueueBindOptions::default(),
+ FieldTable::default(),
+ )
+ .await?;
+
let mut subscriber = sub_channel
.basic_consume(
queue.name().as_str(),
@@ -267,22 +277,36 @@
vs.load(args[2].as_str())?;
+ let mut buffer = Vec::new();
+
while let Some(amqp_message) = subscriber.next().await {
let (_, delivery) = amqp_message.expect("error in consumer");
delivery.ack(BasicAckOptions::default()).await?;
- let chat_message = String::from_utf8(delivery.data)?;
- if let Some((_who, seed)) = chat_message.split_once('\n') {
- let input_sample = &format!("\n{}", seed);
+ if delivery.routing_key.as_str() == "msg.hedgewars" {
+ let chat_message = String::from_utf8_lossy(&delivery.data);
+ if let Some((_who, message)) = chat_message.split_once('\n') {
+ buffer.push('\n');
+ buffer.extend(message.chars());
+ if buffer.len() >= BLOCK_SIZE as usize {
+ let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize);
+ }
+ }
+ } else {
+ let chat_message = String::from_utf8_lossy(&delivery.data);
+ let seed = chat_message.split_once('\n').map(|(_, s)| s).unwrap_or("");
+ buffer.push('\n');
+ buffer.extend(seed.chars());
+
+ if buffer.len() >= BLOCK_SIZE as usize {
+ let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize);
+ }
+
let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device));
- for (idx, c) in input_sample.chars().rev().enumerate() {
- let idx = idx as i64;
- if idx >= BLOCK_SIZE {
- break;
- }
+ for (idx, c) in buffer.iter().rev().enumerate() {
let _filled = input
- .i((0, BLOCK_SIZE - 1 - idx))
- .fill_(data.char_to_label(c).unwrap_or(0) as i64);
+ .i((0, BLOCK_SIZE - 1 - idx as i64))
+ .fill_(data.char_to_label(*c).unwrap_or(0) as i64);
}
let proceeded_message = &sample(&data, &gpt, input);