// libsurfer/cxxrtl/io_worker.rs
use std::{collections::VecDeque, io::Write};

use eyre::{Context, Result};
use log::{error, info, trace};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    sync::mpsc,
};

use crate::channels::IngressSender;

12pub struct CxxrtlWorker<W, R> {
13 write: W,
14 read: R,
15 read_buf: VecDeque<u8>,
16
17 sc_channel: IngressSender<String>,
18 cs_channel: mpsc::Receiver<String>,
19}
20
21impl<W, R> CxxrtlWorker<W, R>
22where
23 W: AsyncWriteExt + Unpin,
24 R: AsyncReadExt + Unpin,
25{
26 pub(crate) fn new(
27 write: W,
28 read: R,
29 sc_channel: IngressSender<String>,
30 cs_channel: mpsc::Receiver<String>,
31 ) -> Self {
32 Self {
33 write,
34 read,
35 read_buf: VecDeque::new(),
36 sc_channel,
37 cs_channel,
38 }
39 }
40
41 pub(crate) async fn start(mut self) {
42 info!("cxxrtl worker is up-and-running");
43 let mut buf = [0; 1024];
44 loop {
45 tokio::select! {
46 rx = self.cs_channel.recv() => {
47 if let Some(msg) = rx {
48 if let Err(e) = self.send_message(msg).await {
49 error!("Failed to send message {e:#?}");
50 }
51 }
52 }
53 count = self.read.read(&mut buf) => {
54 match count {
55 Ok(count) => {
56 trace!("CXXRTL Read {count} from reader");
57 match self.process_stream(count, &mut buf).await {
58 Ok(msgs) => {
59 for msg in msgs {
60 self.sc_channel.send(msg).await.unwrap();
61 }
62 }
63 Err(e) => {
64 error!("Failed to process cxxrtl message ({e:#?})");
65 }
66 }
67 },
68 Err(e) => {
69 error!("Failed to read bytes from cxxrtl {e:#?}. Shutting down client");
70 break;
71 }
72 }
73 }
74 }
75 }
76 }
77
78 async fn process_stream(&mut self, count: usize, buf: &mut [u8; 1024]) -> Result<Vec<String>> {
79 if count != 0 {
80 self.read_buf
81 .write_all(&buf[0..count])
82 .context("Failed to read from cxxrtl tcp socket")?;
83 }
84
85 let mut new_messages = vec![];
86
87 while let Some(idx) = self
88 .read_buf
89 .iter()
90 .enumerate()
91 .find(|(_i, c)| **c == b'\0')
92 {
93 let message = self.read_buf.drain(0..idx.0).collect::<Vec<_>>();
94 self.read_buf.pop_front();
96
97 new_messages
98 .push(String::from_utf8(message).context("Got non-utf8 characters from cxxrtl")?)
99 }
100
101 Ok(new_messages)
102 }
103
104 async fn send_message(&mut self, message: String) -> Result<()> {
105 self.write.write_all(message.as_bytes()).await?;
106 self.write.write_all(b"\0").await?;
107 self.write.flush().await?;
108
109 Ok(())
110 }
111}