libsurfer/cxxrtl/
io_worker.rs1use std::{collections::VecDeque, io::Write};
2
3use eyre::{Context, Result};
4use tokio::{
5 io::{AsyncReadExt, AsyncWriteExt},
6 sync::mpsc,
7};
8use tracing::{error, info, trace};
9
10use crate::channels::IngressSender;
11
12pub struct CxxrtlWorker<W, R> {
13 write: W,
14 read: R,
15 read_buf: VecDeque<u8>,
16
17 sc_channel: IngressSender<String>,
18 cs_channel: mpsc::Receiver<String>,
19}
20
21impl<W, R> CxxrtlWorker<W, R>
22where
23 W: AsyncWriteExt + Unpin,
24 R: AsyncReadExt + Unpin,
25{
26 pub(crate) fn new(
27 write: W,
28 read: R,
29 sc_channel: IngressSender<String>,
30 cs_channel: mpsc::Receiver<String>,
31 ) -> Self {
32 Self {
33 write,
34 read,
35 read_buf: VecDeque::new(),
36 sc_channel,
37 cs_channel,
38 }
39 }
40
41 pub(crate) async fn start(mut self) {
42 info!("cxxrtl worker is up-and-running");
43 let mut buf = [0; 1024];
44 loop {
45 tokio::select! {
46 rx = self.cs_channel.recv() => {
47 if let Some(msg) = rx
48 && let Err(e) = self.send_message(msg).await {
49 error!("Failed to send message {e:#?}");
50 }
51 }
52 count = self.read.read(&mut buf) => {
53 match count {
54 Ok(count) => {
55 trace!("CXXRTL Read {count} from reader");
56 match self.process_stream(count, &mut buf).await {
57 Ok(msgs) => {
58 for msg in msgs {
59 self.sc_channel.send(msg).await.unwrap();
60 }
61 }
62 Err(e) => {
63 error!("Failed to process cxxrtl message ({e:#?})");
64 }
65 }
66 },
67 Err(e) => {
68 error!("Failed to read bytes from cxxrtl {e:#?}. Shutting down client");
69 break;
70 }
71 }
72 }
73 }
74 }
75 }
76
77 async fn process_stream(&mut self, count: usize, buf: &mut [u8; 1024]) -> Result<Vec<String>> {
78 if count != 0 {
79 self.read_buf
80 .write_all(&buf[0..count])
81 .context("Failed to read from cxxrtl tcp socket")?;
82 }
83
84 let mut new_messages = vec![];
85
86 while let Some(idx) = self
87 .read_buf
88 .iter()
89 .enumerate()
90 .find(|(_i, c)| **c == b'\0')
91 {
92 let message = self.read_buf.drain(0..idx.0).collect::<Vec<_>>();
93 self.read_buf.pop_front();
95
96 new_messages
97 .push(String::from_utf8(message).context("Got non-utf8 characters from cxxrtl")?);
98 }
99
100 Ok(new_messages)
101 }
102
103 async fn send_message(&mut self, message: String) -> Result<()> {
104 self.write.write_all(message.as_bytes()).await?;
105 self.write.write_all(b"\0").await?;
106 self.write.flush().await?;
107
108 Ok(())
109 }
110}