libsurfer/cxxrtl/
io_worker.rs

1use std::{collections::VecDeque, io::Write};
2
3use eyre::{Context, Result};
4use log::{error, info, trace};
5use tokio::{
6    io::{AsyncReadExt, AsyncWriteExt},
7    sync::mpsc,
8};
9
10use crate::channels::IngressSender;
11
12pub struct CxxrtlWorker<W, R> {
13    write: W,
14    read: R,
15    read_buf: VecDeque<u8>,
16
17    sc_channel: IngressSender<String>,
18    cs_channel: mpsc::Receiver<String>,
19}
20
21impl<W, R> CxxrtlWorker<W, R>
22where
23    W: AsyncWriteExt + Unpin,
24    R: AsyncReadExt + Unpin,
25{
26    pub(crate) fn new(
27        write: W,
28        read: R,
29        sc_channel: IngressSender<String>,
30        cs_channel: mpsc::Receiver<String>,
31    ) -> Self {
32        Self {
33            write,
34            read,
35            read_buf: VecDeque::new(),
36            sc_channel,
37            cs_channel,
38        }
39    }
40
41    pub(crate) async fn start(mut self) {
42        info!("cxxrtl worker is up-and-running");
43        let mut buf = [0; 1024];
44        loop {
45            tokio::select! {
46                rx = self.cs_channel.recv() => {
47                    if let Some(msg) = rx {
48                        if let Err(e) =  self.send_message(msg).await {
49                            error!("Failed to send message {e:#?}");
50                        }
51                    }
52                }
53                count = self.read.read(&mut buf) => {
54                    match count {
55                        Ok(count) => {
56                            trace!("CXXRTL Read {count} from reader");
57                            match self.process_stream(count, &mut buf).await {
58                                Ok(msgs) => {
59                                    for msg in msgs {
60                                        self.sc_channel.send(msg).await.unwrap();
61                                    }
62                                }
63                                Err(e) => {
64                                    error!("Failed to process cxxrtl message ({e:#?})");
65                                }
66                            }
67                        },
68                        Err(e) => {
69                            error!("Failed to read bytes from cxxrtl {e:#?}. Shutting down client");
70                            break;
71                        }
72                    }
73                }
74            }
75        }
76    }
77
78    async fn process_stream(&mut self, count: usize, buf: &mut [u8; 1024]) -> Result<Vec<String>> {
79        if count != 0 {
80            self.read_buf
81                .write_all(&buf[0..count])
82                .context("Failed to read from cxxrtl tcp socket")?;
83        }
84
85        let mut new_messages = vec![];
86
87        while let Some(idx) = self
88            .read_buf
89            .iter()
90            .enumerate()
91            .find(|(_i, c)| **c == b'\0')
92        {
93            let message = self.read_buf.drain(0..idx.0).collect::<Vec<_>>();
94            // The null byte should not be part of this or the next message message
95            self.read_buf.pop_front();
96
97            new_messages
98                .push(String::from_utf8(message).context("Got non-utf8 characters from cxxrtl")?)
99        }
100
101        Ok(new_messages)
102    }
103
104    async fn send_message(&mut self, message: String) -> Result<()> {
105        self.write.write_all(message.as_bytes()).await?;
106        self.write.write_all(b"\0").await?;
107        self.write.flush().await?;
108
109        Ok(())
110    }
111}