1use std::fmt::Write as _;
2use std::sync::Arc;
3use std::sync::OnceLock;
4use std::sync::mpsc::Sender;
5
6use bincode::Options;
7use eyre::{Result, WrapErr as _, anyhow, bail, eyre};
8use reqwest::StatusCode;
9use thiserror::Error;
10use tracing::{info, warn};
11use wellen::CompressedTimeTable;
12
13use surver::{
14 BINCODE_OPTIONS, HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER, SURFER_VERSION, SurverStatus,
15 WELLEN_VERSION, X_SURFER_VERSION, X_WELLEN_VERSION,
16};
17
18use super::HierarchyResponse;
19use crate::async_util::{perform_async_work, sleep_ms};
20use crate::channels::checked_send;
21use crate::message::Message;
22use crate::wave_source::{LoadOptions, WaveSource};
23use crate::wellen::{BodyResult, HeaderResult};
24
25fn get_client() -> &'static reqwest::Client {
27 static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
28 CLIENT.get_or_init(reqwest::Client::new)
29}
30
31#[derive(Debug, Error)]
32pub enum ReloadError {
33 #[error("File unchanged since last reload")]
34 FileUnchanged,
35 #[error("Unexpected response code: {0}")]
36 UnexpectedStatus(StatusCode),
37 #[error("Network error: {0}")]
38 Network(#[from] reqwest::Error),
39 #[error("Parse error: {0}")]
40 Parse(#[from] serde_json::Error),
41 #[error("Response validation error: {0}")]
42 Validation(#[from] eyre::Report),
43}
44
45fn check_response(server_url: &str, response: &reqwest::Response) -> Result<()> {
46 let server = response
47 .headers()
48 .get(HTTP_SERVER_KEY)
49 .ok_or(eyre!("no server header"))?
50 .to_str()?;
51 if server != HTTP_SERVER_VALUE_SURFER {
52 bail!("Unexpected server {server} from {server_url}");
53 }
54 let surfer_version = response
55 .headers()
56 .get(X_SURFER_VERSION)
57 .ok_or(eyre!("no surfer version header"))?
58 .to_str()?;
59 if surfer_version != SURFER_VERSION {
60 info!(
62 "Surfer version on the server: {surfer_version} does not match client version {SURFER_VERSION}"
63 );
64 }
65 let wellen_version = response
66 .headers()
67 .get(X_WELLEN_VERSION)
68 .ok_or(eyre!("no wellen version header"))?
69 .to_str()?;
70 if wellen_version != WELLEN_VERSION {
71 bail!(
72 "Version incompatibility! The server uses wellen {wellen_version}, our client uses wellen {WELLEN_VERSION}"
73 );
74 }
75 Ok(())
76}
77
78async fn get_status(server: String) -> Result<SurverStatus> {
79 let client = get_client();
80 let response = client.get(format!("{server}/get_status")).send().await?;
81 check_response(&server, &response)?;
82 let body = response.text().await?;
83 let status = serde_json::from_str::<SurverStatus>(&body)?;
84 Ok(status)
85}
86
87async fn reload(
88 server: String,
89 file_index: usize,
90) -> std::result::Result<SurverStatus, ReloadError> {
91 let client = get_client();
92 let response = client
93 .get(format!("{server}/{file_index}/reload"))
94 .send()
95 .await?;
96 check_response(&server, &response)?;
97 let status_code = response.status();
98 let body = response.text().await?;
99 match status_code {
100 StatusCode::NOT_MODIFIED => {
101 info!("File unchanged, no reload needed");
102 Err(ReloadError::FileUnchanged)
103 }
104 StatusCode::ACCEPTED => {
105 info!("File reloaded at server");
106 let status = serde_json::from_str::<SurverStatus>(&body)?;
107 Ok(status)
108 }
109 code => {
110 warn!("Unexpected response code: {code}");
111 Err(ReloadError::UnexpectedStatus(code))
112 }
113 }
114}
115
116async fn get_hierarchy(server: String, file_index: usize) -> Result<HierarchyResponse> {
117 let client = get_client();
118 let response = client
119 .get(format!("{server}/{file_index}/get_hierarchy"))
120 .send()
121 .await?;
122 check_response(&server, &response)?;
123 let compressed = response.bytes().await?;
124 let raw = lz4_flex::decompress_size_prepended(&compressed)?;
125 let mut reader = std::io::Cursor::new(raw);
126 let opts = BINCODE_OPTIONS.allow_trailing_bytes();
128 let file_format: wellen::FileFormat = opts.deserialize_from(&mut reader)?;
129 let hierarchy: wellen::Hierarchy = BINCODE_OPTIONS.deserialize_from(&mut reader)?;
131 Ok(HierarchyResponse {
132 hierarchy,
133 file_format,
134 })
135}
136
137async fn get_time_table(server: String, file_index: usize) -> Result<Vec<wellen::Time>> {
138 let client = get_client();
139 let response = client
140 .get(format!("{server}/{file_index}/get_time_table"))
141 .send()
142 .await?;
143 check_response(&server, &response)?;
144 let compressed_data = response.bytes().await?;
145 let compressed: CompressedTimeTable = BINCODE_OPTIONS.deserialize(&compressed_data)?;
146 let table = compressed.uncompress();
147 Ok(table)
148}
149
/// Number of URL characters that `/<index>` contributes: one for the slash
/// plus the count of decimal digits in `index`.
#[inline]
fn signal_url_len(index: usize) -> usize {
    let digits = match index.checked_ilog10() {
        Some(log) => log as usize + 1,
        // `checked_ilog10` is `None` only for 0, which is still one digit.
        None => 1,
    };
    1 + digits
}
157
158pub async fn get_signals(
159 server: String,
160 signals: &[wellen::SignalRef],
161 max_url_length: u16,
162 file_index: usize,
163) -> Result<Vec<(wellen::SignalRef, wellen::Signal)>> {
164 if signals.is_empty() {
165 return Ok(vec![]);
166 }
167
168 let max_url_length = max_url_length as usize;
169 let base_url = format!("{server}/{file_index}/get_signals");
170 let base_len = base_url.len();
171
172 let mut all_results = Vec::with_capacity(signals.len());
173 let mut current_batch = Vec::new();
174 let mut current_url_len = base_len;
175
176 for signal in signals {
177 let signal_len = signal_url_len(signal.index());
179
180 if current_url_len + signal_len > max_url_length && !current_batch.is_empty() {
182 info!(
183 "Fetching batch of {} signals due to URL length limit",
184 current_batch.len()
185 );
186 let batch_results = get_signals_batch(&base_url, ¤t_batch).await?;
188 all_results.extend(batch_results);
189
190 current_batch.clear();
192 current_url_len = base_len;
193 }
194
195 current_batch.push(*signal);
196 current_url_len += signal_len;
197 }
198
199 if !current_batch.is_empty() {
201 let batch_results = get_signals_batch(&base_url, ¤t_batch).await?;
202 all_results.extend(batch_results);
203 }
204
205 Ok(all_results)
206}
207
208#[inline]
211fn format_signal_url(base_url: &str, signals: &[wellen::SignalRef]) -> String {
212 let mut url = base_url.to_string();
213 for signal in signals {
214 write!(url, "/{}", signal.index()).unwrap();
215 }
216 url
217}
218
219async fn get_signals_batch(
220 base_url: &str,
221 signals: &[wellen::SignalRef],
222) -> Result<Vec<(wellen::SignalRef, wellen::Signal)>> {
223 let client = get_client();
224 let url = format_signal_url(base_url, signals);
225
226 let response = client.get(url).send().await?;
227 check_response(base_url, &response)?;
228 let data = response.bytes().await?;
229 let mut reader = std::io::Cursor::new(data);
230 let num_ids: u64 = leb128::read::unsigned(&mut reader)?;
231 if num_ids > signals.len() as u64 {
232 bail!(
233 "Too many signals in response: {num_ids}, expected {}",
234 signals.len()
235 );
236 }
237 if num_ids == 0 {
238 return Ok(vec![]);
239 }
240
241 let opts = BINCODE_OPTIONS.allow_trailing_bytes();
242 let mut out = Vec::with_capacity(num_ids as usize);
243 for _ in 0..(num_ids - 1) {
244 let compressed: wellen::CompressedSignal = opts.deserialize_from(&mut reader)?;
245 let signal = compressed.uncompress();
246 out.push((signal.signal_ref(), signal));
247 }
248 let compressed: wellen::CompressedSignal = BINCODE_OPTIONS.deserialize_from(&mut reader)?;
250 let signal = compressed.uncompress();
251 out.push((signal.signal_ref(), signal));
252 Ok(out)
253}
254
255pub fn get_hierarchy_from_server(
256 sender: Sender<Message>,
257 server: String,
258 load_options: LoadOptions,
259 file_index: usize,
260) {
261 let start = web_time::Instant::now();
262 let source = WaveSource::Url(server.clone());
263
264 perform_async_work(async move {
265 let res = get_hierarchy(server.clone(), file_index)
266 .await
267 .map_err(|e| anyhow!("{e:?}"))
268 .with_context(|| format!("Failed to retrieve hierarchy from remote server {server}"));
269
270 let msg = match res {
271 Ok(h) => {
272 let header =
273 HeaderResult::Remote(Arc::new(h.hierarchy), h.file_format, server, file_index);
274 Message::WaveHeaderLoaded(start, source, load_options, header)
275 }
276 Err(e) => Message::Error(e),
277 };
278 checked_send(&sender, msg);
279 });
280}
281
282pub fn get_time_table_from_server(sender: Sender<Message>, server: String, file_index: usize) {
283 let start = web_time::Instant::now();
284 let source = WaveSource::Url(server.clone());
285
286 perform_async_work(async move {
287 let res = get_time_table(server.clone(), file_index)
288 .await
289 .map_err(|e| anyhow!("{e:?}"))
290 .with_context(|| format!("Failed to retrieve time table from remote server {server}"));
291
292 let msg = match res {
293 Ok(table) => Message::WaveBodyLoaded(start, source, BodyResult::Remote(table, server)),
294 Err(e) => Message::Error(e),
295 };
296 checked_send(&sender, msg);
297 });
298}
299
300pub fn get_server_status(sender: Sender<Message>, server: String, delay_ms: u64) {
301 let start = web_time::Instant::now();
302 perform_async_work(async move {
303 sleep_ms(delay_ms).await;
304 let res = get_status(server.clone())
305 .await
306 .map_err(|e| anyhow!("{e:?}"))
307 .with_context(|| format!("Failed to retrieve status from remote server {server}"));
308
309 let msg = match res {
310 Ok(status) => Message::SetSurverStatus(start, server, status),
311 Err(e) => Message::Error(e),
312 };
313 checked_send(&sender, msg);
314 });
315}
316
317pub fn server_reload(
318 sender: Sender<Message>,
319 server: String,
320 load_options: LoadOptions,
321 file_index: usize,
322) {
323 let start = web_time::Instant::now();
324 perform_async_work(async move {
325 let res = reload(server.clone(), file_index).await;
326 let mut request_hierarchy = false;
327
328 let msg = match res {
329 Ok(status) => {
330 request_hierarchy = true;
331 Message::SetSurverStatus(start, server.clone(), status)
332 }
333 Err(crate::remote::ReloadError::FileUnchanged) => Message::StopProgressTracker,
334 Err(e) => {
335 let err = anyhow!("{e:?}");
336 Message::Error(err)
337 }
338 };
339 checked_send(&sender, msg);
340 if request_hierarchy {
341 get_hierarchy_from_server(sender, server, load_options, file_index);
342 }
343 });
344}
345
/// Unit tests for URL-length bookkeeping, URL construction and the empty
/// request fast path. Gated on `cfg(test)` so it is not compiled into
/// regular builds.
#[cfg(test)]
mod tests {
    use super::{format_signal_url, get_signals, signal_url_len};

    #[test]
    fn test_signal_url_length_calculation() {
        // Expected value: one `/` plus the number of decimal digits.
        let cases = [
            (0, 2),
            (1, 2),
            (9, 2),
            (10, 3),
            (99, 3),
            (100, 4),
            (999, 4),
            (1000, 5),
            (65535, 6),
        ];
        for (index, expected) in cases {
            assert_eq!(signal_url_len(index), expected, "index {index}");
        }
    }

    #[test]
    fn test_empty_signals_returns_empty() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let signals: Vec<wellen::SignalRef> = vec![];
            // An empty request returns before any network access, so no server
            // needs to be listening at this address.
            let result = get_signals("http://localhost:8080".to_string(), &signals, 1000, 0).await;

            assert!(result.is_ok());
            assert_eq!(result.unwrap().len(), 0);
        });
    }

    #[test]
    fn test_boundary_signal_indices() {
        let boundary_indices = [0, 1, 9, 10, 99, 100, 999, 1000, 9999, 10000];

        for idx in boundary_indices {
            let sig_ref = wellen::SignalRef::from_index(idx).unwrap();
            let len = signal_url_len(sig_ref.index());
            let actual = format!("/{idx}");

            assert_eq!(
                len,
                actual.len(),
                "URL length calculation mismatch for index {idx}"
            );
        }
    }

    #[test]
    fn test_url_construction_format() {
        let base_url = "http://localhost:8080/get_signals";
        let signals: Vec<wellen::SignalRef> = vec![
            wellen::SignalRef::from_index(1),
            wellen::SignalRef::from_index(42),
            wellen::SignalRef::from_index(999),
        ]
        .into_iter()
        .flatten()
        .collect();

        let url = format_signal_url(base_url, &signals);

        assert_eq!(url, "http://localhost:8080/get_signals/1/42/999");
    }
}