Skip to main content

libsurfer/remote/
client.rs

1use std::fmt::Write as _;
2use std::sync::Arc;
3use std::sync::OnceLock;
4use std::sync::mpsc::Sender;
5
6use bincode::Options;
7use eyre::{Context, Result, anyhow};
8use eyre::{bail, eyre};
9use reqwest::StatusCode;
10use thiserror::Error;
11use tracing::{error, info, warn};
12use wellen::CompressedTimeTable;
13
14use surver::{
15    BINCODE_OPTIONS, HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER, SURFER_VERSION, SurverStatus,
16    WELLEN_VERSION, X_SURFER_VERSION, X_WELLEN_VERSION,
17};
18
19use super::HierarchyResponse;
20use crate::async_util::{perform_async_work, sleep_ms};
21use crate::message::Message;
22use crate::wave_source::{LoadOptions, WaveSource};
23use crate::wellen::{BodyResult, HeaderResult};
24
25/// Returns a shared reqwest client to reuse HTTP connections and reduce TLS overhead.
26fn get_client() -> &'static reqwest::Client {
27    static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
28    CLIENT.get_or_init(reqwest::Client::new)
29}
30
31#[derive(Debug, Error)]
32pub enum ReloadError {
33    #[error("File unchanged since last reload")]
34    FileUnchanged,
35    #[error("Unexpected response code: {0}")]
36    UnexpectedStatus(StatusCode),
37    #[error("Network error: {0}")]
38    Network(#[from] reqwest::Error),
39    #[error("Parse error: {0}")]
40    Parse(#[from] serde_json::Error),
41    #[error("Response validation error: {0}")]
42    Validation(#[from] eyre::Report),
43}
44
45fn check_response(server_url: &str, response: &reqwest::Response) -> Result<()> {
46    let server = response
47        .headers()
48        .get(HTTP_SERVER_KEY)
49        .ok_or(eyre!("no server header"))?
50        .to_str()?;
51    if server != HTTP_SERVER_VALUE_SURFER {
52        bail!("Unexpected server {server} from {server_url}");
53    }
54    let surfer_version = response
55        .headers()
56        .get(X_SURFER_VERSION)
57        .ok_or(eyre!("no surfer version header"))?
58        .to_str()?;
59    if surfer_version != SURFER_VERSION {
60        // this mismatch may be OK as long as the wellen version matches
61        info!(
62            "Surfer version on the server: {surfer_version} does not match client version {SURFER_VERSION}"
63        );
64    }
65    let wellen_version = response
66        .headers()
67        .get(X_WELLEN_VERSION)
68        .ok_or(eyre!("no wellen version header"))?
69        .to_str()?;
70    if wellen_version != WELLEN_VERSION {
71        bail!(
72            "Version incompatibility! The server uses wellen {wellen_version}, our client uses wellen {WELLEN_VERSION}"
73        );
74    }
75    Ok(())
76}
77
78async fn get_status(server: String) -> Result<SurverStatus> {
79    let client = get_client();
80    let response = client.get(format!("{server}/get_status")).send().await?;
81    check_response(&server, &response)?;
82    let body = response.text().await?;
83    let status = serde_json::from_str::<SurverStatus>(&body)?;
84    Ok(status)
85}
86
87async fn reload(
88    server: String,
89    file_index: usize,
90) -> std::result::Result<SurverStatus, ReloadError> {
91    let client = get_client();
92    let response = client
93        .get(format!("{server}/{file_index}/reload"))
94        .send()
95        .await?;
96    check_response(&server, &response)?;
97    let status_code = response.status();
98    let body = response.text().await?;
99    match status_code {
100        StatusCode::NOT_MODIFIED => {
101            info!("File unchanged, no reload needed");
102            Err(ReloadError::FileUnchanged)
103        }
104        StatusCode::ACCEPTED => {
105            info!("File reloaded at server");
106            let status = serde_json::from_str::<SurverStatus>(&body)?;
107            Ok(status)
108        }
109        code => {
110            warn!("Unexpected response code: {code}");
111            Err(ReloadError::UnexpectedStatus(code))
112        }
113    }
114}
115
116async fn get_hierarchy(server: String, file_index: usize) -> Result<HierarchyResponse> {
117    let client = get_client();
118    let response = client
119        .get(format!("{server}/{file_index}/get_hierarchy"))
120        .send()
121        .await?;
122    check_response(&server, &response)?;
123    let compressed = response.bytes().await?;
124    let raw = lz4_flex::decompress_size_prepended(&compressed)?;
125    let mut reader = std::io::Cursor::new(raw);
126    // first we read a value, expecting there to be more bytes
127    let opts = BINCODE_OPTIONS.allow_trailing_bytes();
128    let file_format: wellen::FileFormat = opts.deserialize_from(&mut reader)?;
129    // the last value should consume all remaining bytes
130    let hierarchy: wellen::Hierarchy = BINCODE_OPTIONS.deserialize_from(&mut reader)?;
131    Ok(HierarchyResponse {
132        hierarchy,
133        file_format,
134    })
135}
136
137async fn get_time_table(server: String, file_index: usize) -> Result<Vec<wellen::Time>> {
138    let client = get_client();
139    let response = client
140        .get(format!("{server}/{file_index}/get_time_table"))
141        .send()
142        .await?;
143    check_response(&server, &response)?;
144    let compressed_data = response.bytes().await?;
145    let compressed: CompressedTimeTable = BINCODE_OPTIONS.deserialize(&compressed_data)?;
146    let table = compressed.uncompress();
147    Ok(table)
148}
149
/// Number of characters that the path segment `/<index>` adds to a URL.
///
/// Computed arithmetically instead of formatting the number into a string.
/// Kept as a separate function so it can be unit-tested.
#[inline]
fn signal_url_len(index: usize) -> usize {
    // Decimal digits of `index`; `checked_ilog10` is `None` only for 0,
    // which is still written with one digit.
    let digits = match index.checked_ilog10() {
        Some(log) => log as usize + 1,
        None => 1,
    };
    // One extra character for the leading '/'.
    digits + 1
}
157
158pub async fn get_signals(
159    server: String,
160    signals: &[wellen::SignalRef],
161    max_url_length: u16,
162    file_index: usize,
163) -> Result<Vec<(wellen::SignalRef, wellen::Signal)>> {
164    if signals.is_empty() {
165        return Ok(vec![]);
166    }
167
168    let max_url_length = max_url_length as usize;
169    let base_url = format!("{server}/{file_index}/get_signals");
170    let base_len = base_url.len();
171
172    let mut all_results = Vec::with_capacity(signals.len());
173    let mut current_batch = Vec::new();
174    let mut current_url_len = base_len;
175
176    for signal in signals {
177        // Each signal adds: "/" + digits
178        let signal_len = signal_url_len(signal.index());
179
180        // Check if adding this signal would exceed the limit
181        if current_url_len + signal_len > max_url_length && !current_batch.is_empty() {
182            info!(
183                "Fetching batch of {} signals due to URL length limit",
184                current_batch.len()
185            );
186            // Fetch current batch
187            let batch_results = get_signals_batch(&base_url, &current_batch).await?;
188            all_results.extend(batch_results);
189
190            // Start new batch
191            current_batch.clear();
192            current_url_len = base_len;
193        }
194
195        current_batch.push(*signal);
196        current_url_len += signal_len;
197    }
198
199    // Fetch remaining batch
200    if !current_batch.is_empty() {
201        let batch_results = get_signals_batch(&base_url, &current_batch).await?;
202        all_results.extend(batch_results);
203    }
204
205    Ok(all_results)
206}
207
208// Helper to format signal URL
209// Extracted for testing
210#[inline]
211fn format_signal_url(base_url: &str, signals: &[wellen::SignalRef]) -> String {
212    let mut url = base_url.to_string();
213    for signal in signals {
214        write!(url, "/{}", signal.index()).unwrap();
215    }
216    url
217}
218
219async fn get_signals_batch(
220    base_url: &str,
221    signals: &[wellen::SignalRef],
222) -> Result<Vec<(wellen::SignalRef, wellen::Signal)>> {
223    let client = get_client();
224    let url = format_signal_url(base_url, signals);
225
226    let response = client.get(url).send().await?;
227    check_response(base_url, &response)?;
228    let data = response.bytes().await?;
229    let mut reader = std::io::Cursor::new(data);
230    let num_ids: u64 = leb128::read::unsigned(&mut reader)?;
231    if num_ids > signals.len() as u64 {
232        bail!(
233            "Too many signals in response: {num_ids}, expected {}",
234            signals.len()
235        );
236    }
237    if num_ids == 0 {
238        return Ok(vec![]);
239    }
240
241    let opts = BINCODE_OPTIONS.allow_trailing_bytes();
242    let mut out = Vec::with_capacity(num_ids as usize);
243    for _ in 0..(num_ids - 1) {
244        let compressed: wellen::CompressedSignal = opts.deserialize_from(&mut reader)?;
245        let signal = compressed.uncompress();
246        out.push((signal.signal_ref(), signal));
247    }
248    // for the final signal, we expect to consume all bytes
249    let compressed: wellen::CompressedSignal = BINCODE_OPTIONS.deserialize_from(&mut reader)?;
250    let signal = compressed.uncompress();
251    out.push((signal.signal_ref(), signal));
252    Ok(out)
253}
254
255pub fn get_hierarchy_from_server(
256    sender: Sender<Message>,
257    server: String,
258    load_options: LoadOptions,
259    file_index: usize,
260) {
261    let start = web_time::Instant::now();
262    let source = WaveSource::Url(server.clone());
263
264    perform_async_work(async move {
265        let res = get_hierarchy(server.clone(), file_index)
266            .await
267            .map_err(|e| anyhow!("{e:?}"))
268            .with_context(|| format!("Failed to retrieve hierarchy from remote server {server}"));
269
270        let msg = match res {
271            Ok(h) => {
272                let header =
273                    HeaderResult::Remote(Arc::new(h.hierarchy), h.file_format, server, file_index);
274                Message::WaveHeaderLoaded(start, source, load_options, header)
275            }
276            Err(e) => Message::Error(e),
277        };
278        if let Err(e) = sender.send(msg) {
279            error!("Failed to send message: {e}");
280        }
281    });
282}
283
284pub fn get_time_table_from_server(sender: Sender<Message>, server: String, file_index: usize) {
285    let start = web_time::Instant::now();
286    let source = WaveSource::Url(server.clone());
287
288    perform_async_work(async move {
289        let res = get_time_table(server.clone(), file_index)
290            .await
291            .map_err(|e| anyhow!("{e:?}"))
292            .with_context(|| format!("Failed to retrieve time table from remote server {server}"));
293
294        let msg = match res {
295            Ok(table) => Message::WaveBodyLoaded(start, source, BodyResult::Remote(table, server)),
296            Err(e) => Message::Error(e),
297        };
298        if let Err(e) = sender.send(msg) {
299            error!("Failed to send message: {e}");
300        }
301    });
302}
303
304pub fn get_server_status(sender: Sender<Message>, server: String, delay_ms: u64) {
305    let start = web_time::Instant::now();
306    perform_async_work(async move {
307        sleep_ms(delay_ms).await;
308        let res = get_status(server.clone())
309            .await
310            .map_err(|e| anyhow!("{e:?}"))
311            .with_context(|| format!("Failed to retrieve status from remote server {server}"));
312
313        let msg = match res {
314            Ok(status) => Message::SetSurverStatus(start, server, status),
315            Err(e) => Message::Error(e),
316        };
317        if let Err(e) = sender.send(msg) {
318            error!("Failed to send message: {e}");
319        }
320    });
321}
322
323pub fn server_reload(
324    sender: Sender<Message>,
325    server: String,
326    load_options: LoadOptions,
327    file_index: usize,
328) {
329    let start = web_time::Instant::now();
330    perform_async_work(async move {
331        let res = reload(server.clone(), file_index).await;
332        let mut request_hierarchy = false;
333
334        let msg = match res {
335            Ok(status) => {
336                request_hierarchy = true;
337                Message::SetSurverStatus(start, server.clone(), status)
338            }
339            Err(crate::remote::ReloadError::FileUnchanged) => Message::StopProgressTracker,
340            Err(e) => {
341                let err = anyhow!("{e:?}");
342                Message::Error(err)
343            }
344        };
345        if let Err(e) = sender.send(msg) {
346            error!("Failed to send message: {e}");
347        }
348        if request_hierarchy {
349            get_hierarchy_from_server(sender, server, load_options, file_index);
350        }
351    });
352}
353
// Unit tests for the URL-batching helpers and for the empty-input fast path of
// `get_signals`. Gated behind `cfg(test)` so the module (and its tokio usage)
// is not compiled into regular builds.
#[cfg(test)]
mod tests {
    #[test]
    fn test_signal_url_length_calculation() {
        use crate::remote::client::signal_url_len;
        // Test edge cases for digit calculation
        assert_eq!(signal_url_len(0), 2); // "/0" -> 2 chars
        assert_eq!(signal_url_len(1), 2); // "/1" -> 2 chars
        assert_eq!(signal_url_len(9), 2); // "/9" -> 2 chars
        assert_eq!(signal_url_len(10), 3); // "/10" -> 3 chars
        assert_eq!(signal_url_len(99), 3); // "/99" -> 3 chars
        assert_eq!(signal_url_len(100), 4); // "/100" -> 4 chars
        assert_eq!(signal_url_len(999), 4); // "/999" -> 4 chars
        assert_eq!(signal_url_len(1000), 5); // "/1000" -> 5 chars
        assert_eq!(signal_url_len(65535), 6); // "/65535" -> 6 chars
    }

    #[test]
    fn test_empty_signals_returns_empty() {
        use crate::remote::get_signals;
        // Create a mock async runtime for testing
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let signals: Vec<wellen::SignalRef> = vec![];
            let result = get_signals("http://localhost:8080".to_string(), &signals, 1000, 0).await;

            // Should return Ok with empty vec without making any network calls
            assert!(result.is_ok());
            assert_eq!(result.unwrap().len(), 0);
        });
    }

    #[test]
    fn test_boundary_signal_indices() {
        use crate::remote::client::signal_url_len;
        // Test that we handle boundary cases correctly
        let boundary_indices = [0, 1, 9, 10, 99, 100, 999, 1000, 9999, 10000];

        for idx in boundary_indices {
            let sig_ref = wellen::SignalRef::from_index(idx);
            let len = signal_url_len(sig_ref.unwrap().index());

            // Verify the calculated length matches actual string length
            let actual = format!("/{idx}");
            assert_eq!(
                len,
                actual.len(),
                "URL length calculation mismatch for index {}: expected {}, got {}",
                idx,
                actual.len(),
                len
            );
        }
    }

    #[test]
    fn test_url_construction_format() {
        use crate::remote::client::format_signal_url;
        // Verify URL format matches expected pattern
        let base_url = "http://localhost:8080/get_signals";
        let signals: Vec<wellen::SignalRef> = vec![
            wellen::SignalRef::from_index(1),
            wellen::SignalRef::from_index(42),
            wellen::SignalRef::from_index(999),
        ]
        .into_iter()
        .flatten()
        .collect();

        let url = format_signal_url(base_url, &signals);

        assert_eq!(url, "http://localhost:8080/get_signals/1/42/999");
    }
}