Skip to main content

libsurfer/cxxrtl/
io_worker.rs

1use std::{collections::VecDeque, io::Write};
2
3use eyre::{Context, Result};
4use tokio::{
5    io::{AsyncReadExt, AsyncWriteExt},
6    sync::mpsc,
7};
8use tracing::{error, info, trace};
9
10use crate::channels::IngressSender;
11
12pub struct CxxrtlWorker<W, R> {
13    write: W,
14    read: R,
15    read_buf: VecDeque<u8>,
16
17    sc_channel: IngressSender<String>,
18    cs_channel: mpsc::Receiver<String>,
19}
20
21impl<W, R> CxxrtlWorker<W, R>
22where
23    W: AsyncWriteExt + Unpin,
24    R: AsyncReadExt + Unpin,
25{
26    pub(crate) fn new(
27        write: W,
28        read: R,
29        sc_channel: IngressSender<String>,
30        cs_channel: mpsc::Receiver<String>,
31    ) -> Self {
32        Self {
33            write,
34            read,
35            read_buf: VecDeque::new(),
36            sc_channel,
37            cs_channel,
38        }
39    }
40
41    pub(crate) async fn start(mut self) {
42        info!("cxxrtl worker is up-and-running");
43        let mut buf = [0; 1024];
44        loop {
45            tokio::select! {
46                rx = self.cs_channel.recv() => {
47                    if let Some(msg) = rx
48                        && let Err(e) =  self.send_message(msg).await {
49                            error!("Failed to send message {e:#?}");
50                        }
51                }
52                count = self.read.read(&mut buf) => {
53                    match count {
54                        Ok(count) => {
55                            trace!("CXXRTL Read {count} from reader");
56                            match self.process_stream(count, &mut buf).await {
57                                Ok(msgs) => {
58                                    for msg in msgs {
59                                        self.sc_channel.send(msg).await.unwrap();
60                                    }
61                                }
62                                Err(e) => {
63                                    error!("Failed to process cxxrtl message ({e:#?})");
64                                }
65                            }
66                        },
67                        Err(e) => {
68                            error!("Failed to read bytes from cxxrtl {e:#?}. Shutting down client");
69                            break;
70                        }
71                    }
72                }
73            }
74        }
75    }
76
77    async fn process_stream(&mut self, count: usize, buf: &mut [u8; 1024]) -> Result<Vec<String>> {
78        if count != 0 {
79            self.read_buf
80                .write_all(&buf[0..count])
81                .context("Failed to read from cxxrtl tcp socket")?;
82        }
83
84        let mut new_messages = vec![];
85
86        while let Some(idx) = self
87            .read_buf
88            .iter()
89            .enumerate()
90            .find(|(_i, c)| **c == b'\0')
91        {
92            let message = self.read_buf.drain(0..idx.0).collect::<Vec<_>>();
93            // The null byte should not be part of this or the next message message
94            self.read_buf.pop_front();
95
96            new_messages
97                .push(String::from_utf8(message).context("Got non-utf8 characters from cxxrtl")?);
98        }
99
100        Ok(new_messages)
101    }
102
103    async fn send_message(&mut self, message: String) -> Result<()> {
104        self.write.write_all(message.as_bytes()).await?;
105        self.write.write_all(b"\0").await?;
106        self.write.flush().await?;
107
108        Ok(())
109    }
110}