Skip to main content

surver/
server.rs

1//! Handling of external communication in Surver.
2use bincode::Options;
3use eyre::{Context, Result, anyhow, bail};
4use http_body_util::Full;
5use hyper::body::Bytes;
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response, StatusCode};
9use hyper_util::rt::TokioIo;
10use std::collections::HashMap;
11use std::fs;
12use std::iter::repeat_with;
13use std::net::SocketAddr;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::mpsc::Sender;
16use std::sync::{Arc, RwLock};
17use std::time::{Instant, SystemTime};
18use tokio::net::TcpListener;
19use tokio::sync::Notify;
20use tracing::{error, info, warn};
21use wellen::{
22    CompressedSignal, CompressedTimeTable, FileFormat, Hierarchy, Signal, SignalRef, Time, viewers,
23};
24
25use crate::{
26    BINCODE_OPTIONS, HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER, SURFER_VERSION, SurverFileInfo,
27    SurverStatus, WELLEN_SURFER_DEFAULT_OPTIONS, WELLEN_VERSION, X_SURFER_VERSION,
28    X_WELLEN_VERSION, modification_time_string,
29};
30
/// Immutable data shared by all request handlers.
struct ReadOnly {
    /// Base URL of the server (`http://<addr>/<token>`), shown on the info page.
    url: String,
    /// Secret that must be the first path segment of a request
    /// (only the favicon is served without it).
    token: String,
}
35
/// Server-side state of one served waveform file.
struct FileInfo {
    /// Path of the waveform file as given to `surver_main`.
    filename: String,
    /// Hierarchy obtained from the file header.
    hierarchy: Hierarchy,
    /// File format reported by the header parser.
    file_format: FileFormat,
    /// Length of the header in bytes (currently always initialized to 0).
    header_len: u64,
    /// Length of the body in bytes as reported by the header parser.
    body_len: u64,
    /// Progress counter handed to the body parser; reset to 0 when a reload starts.
    body_progress: Arc<AtomicU64>,
    /// Signalled by the loader after a body load and after new signals were stored.
    notify: Arc<Notify>,
    /// Time table of the file; empty until the body has been loaded.
    timetable: Vec<Time>,
    /// Signals loaded so far; cleared whenever the body is (re)loaded.
    signals: HashMap<SignalRef, Signal>,
    /// Set by the `reload` command, cleared by the loader once the body is loaded.
    reloading: bool,
    /// Cleared when a reload is requested, set once the body has been loaded.
    last_reload_ok: bool,
    /// When the body was last loaded, if ever.
    last_reload_time: Option<Instant>,
    /// Modification time of the file that was recorded last, if any.
    last_modification_time: Option<SystemTime>,
}
51
/// Mutable server state, shared between the request handlers and the loader threads.
#[derive(Default)]
struct SurverState {
    /// One entry per served file; the position in this vector is the file index
    /// used in request paths and by the loader threads.
    file_infos: Vec<FileInfo>,
}
56
57impl FileInfo {
58    fn modification_time_string(&self) -> String {
59        modification_time_string(self.last_modification_time)
60    }
61
62    fn reload_time_string(&self) -> String {
63        if let Some(time) = self.last_reload_time {
64            return format!("{:?} ago", time.elapsed());
65        }
66        "never".to_string()
67    }
68
69    pub fn html_table_line(&self) -> String {
70        let bytes_loaded = self.body_progress.load(Ordering::SeqCst);
71
72        let progress = if bytes_loaded == self.body_len {
73            format!(
74                "{} loaded",
75                bytesize::ByteSize::b(self.body_len + self.header_len)
76            )
77        } else {
78            format!(
79                "{} / {}",
80                bytesize::ByteSize::b(bytes_loaded + self.header_len),
81                bytesize::ByteSize::b(self.body_len + self.header_len)
82            )
83        };
84
85        format!(
86            "<tr><td>{}</td><td>{}</td><td>{}</td><td>{}</td></tr>",
87            self.filename,
88            progress,
89            self.modification_time_string(),
90            self.reload_time_string()
91        )
92    }
93}
94
95impl From<&FileInfo> for SurverFileInfo {
96    fn from(file_info: &FileInfo) -> Self {
97        Self {
98            bytes: file_info.body_len + file_info.header_len,
99            bytes_loaded: file_info.body_progress.load(Ordering::SeqCst) + file_info.header_len,
100            filename: file_info.filename.clone(),
101            format: file_info.file_format,
102            reloading: file_info.reloading,
103            last_load_ok: file_info.last_reload_ok,
104            last_modification_time: file_info.last_modification_time,
105        }
106    }
107}
/// Messages sent from the request handlers to the loader thread of a file.
enum LoaderMessage {
    /// Load the given signals (if not loaded yet) and store them in the shared state.
    SignalRequest(SignalRequest),
    /// Re-read the waveform file from disk.
    Reload,
}

/// List of signals requested by a client.
type SignalRequest = Vec<SignalRef>;
114
115fn get_info_page(shared: &Arc<ReadOnly>, state: &Arc<RwLock<SurverState>>) -> String {
116    let state_guard = state.read().expect("State lock poisoned in get_info_page");
117    let html_table_content = state_guard
118        .file_infos
119        .iter()
120        .map(FileInfo::html_table_line)
121        .collect::<Vec<_>>()
122        .join("\n");
123    drop(state_guard);
124
125    format!(
126        r#"
127    <!DOCTYPE html><html lang="en">
128    <head>
129    <link rel="icon" href="favicon.ico" sizes="any">
130    <title>Surver - Surfer Remote Server</title>
131    </head>
132    <body>
133    <h1>Surver - Surfer Remote Server</h1>
134    <b>To connect, run:</b> <code>surfer {}</code><br>
135    <b>Wellen version:</b> {WELLEN_VERSION}<br>
136    <b>Surfer version:</b> {SURFER_VERSION}<br>
137    <table border="1" cellpadding="5" cellspacing="0">
138    <tr><th>Filename</th><th>Load progress</th><th>File modification time</th><th>(Re)load time</th></tr>
139    {}
140    </table>
141    </body></html>
142    "#,
143        shared.url, html_table_content
144    )
145}
146
147fn get_hierarchy(state: &Arc<RwLock<SurverState>>, file_index: usize) -> Result<Vec<u8>> {
148    let state_guard = state.read().expect("State lock poisoned in get_hierarchy");
149    let file_info = &state_guard.file_infos[file_index];
150    let mut raw = BINCODE_OPTIONS.serialize(&file_info.file_format)?;
151    let mut raw2 = BINCODE_OPTIONS.serialize(&file_info.hierarchy)?;
152    drop(state_guard);
153    raw.append(&mut raw2);
154    let compressed = lz4_flex::compress_prepend_size(&raw);
155    info!(
156        "Sending hierarchy. {} raw, {} compressed.",
157        bytesize::ByteSize::b(raw.len() as u64),
158        bytesize::ByteSize::b(compressed.len() as u64)
159    );
160    Ok(compressed)
161}
162
163async fn get_timetable(state: &Arc<RwLock<SurverState>>, file_index: usize) -> Result<Vec<u8>> {
164    // poll to see when the time table is available
165    #[allow(unused_assignments)]
166    let mut table = vec![];
167    loop {
168        {
169            let state = state.read().unwrap();
170            if !state.file_infos[file_index].timetable.is_empty() {
171                table.clone_from(&state.file_infos[file_index].timetable);
172                break;
173            }
174        }
175        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
176    }
177    let raw_size = table.len() * std::mem::size_of::<Time>();
178    let compressed = BINCODE_OPTIONS.serialize(&CompressedTimeTable::compress(&table))?;
179    info!(
180        "Sending timetable. {} raw, {} compressed.",
181        bytesize::ByteSize::b(raw_size as u64),
182        bytesize::ByteSize::b(compressed.len() as u64)
183    );
184    Ok(compressed)
185}
186
187fn get_status(state: &Arc<RwLock<SurverState>>) -> Result<Vec<u8>> {
188    let state_guard = state.read().expect("State lock poisoned in get_status");
189    let file_infos = state_guard
190        .file_infos
191        .iter()
192        .map(SurverFileInfo::from)
193        .collect::<Vec<_>>();
194    drop(state_guard);
195    let status = SurverStatus {
196        wellen_version: WELLEN_VERSION.to_string(),
197        surfer_version: SURFER_VERSION.to_string(),
198        file_infos,
199    };
200    Ok(serde_json::to_vec(&status)?)
201}
202
203async fn get_signals(
204    state: &Arc<RwLock<SurverState>>,
205    file_index: usize,
206    txs: &[Sender<LoaderMessage>],
207    id_strings: &[&str],
208) -> Result<Vec<u8>> {
209    let ids = id_strings
210        .iter()
211        .map(|id_str| {
212            id_str
213                .parse::<u64>()
214                .map_err(|e| anyhow!("Failed to parse signal id `{id_str}`: {e:#}"))
215                .and_then(|index| {
216                    SignalRef::from_index(index as usize)
217                        .ok_or_else(|| anyhow!("Invalid signal index: {}", index))
218                })
219        })
220        .collect::<Result<Vec<SignalRef>>>()?;
221
222    if ids.is_empty() {
223        return Ok(vec![]);
224    }
225    let num_ids = ids.len();
226
227    // send request to background thread
228    txs[file_index].send(LoaderMessage::SignalRequest(ids.clone()))?;
229
230    let notify = {
231        let state_guard = state.read().expect("State lock poisoned in get_signals");
232        state_guard.file_infos[file_index].notify.clone()
233    };
234
235    // Wait for all signals to be loaded
236    let mut data = vec![];
237    leb128::write::unsigned(&mut data, num_ids as u64)?;
238    let mut raw_size = 0;
239    loop {
240        {
241            let state_guard = state.read().expect("State lock poisoned in get_signals");
242            if ids
243                .iter()
244                .all(|id| state_guard.file_infos[file_index].signals.contains_key(id))
245            {
246                for id in ids {
247                    let signal = &state_guard.file_infos[file_index].signals[&id];
248                    raw_size += BINCODE_OPTIONS.serialize(signal)?.len();
249                    let comp = CompressedSignal::compress(signal);
250                    data.append(&mut BINCODE_OPTIONS.serialize(&comp)?);
251                }
252                break;
253            }
254        };
255        // Wait for notification that signals have been loaded
256        notify.notified().await;
257    }
258    info!(
259        "Sending {} signals. {} raw, {} compressed.",
260        num_ids,
261        bytesize::ByteSize::b(raw_size as u64),
262        bytesize::ByteSize::b(data.len() as u64)
263    );
264    Ok(data)
265}
266
/// Name of the HTTP header that carries the MIME type of a response body.
const CONTENT_TYPE: &str = "Content-Type";
/// MIME type of JSON response bodies.
const JSON_MIME: &str = "application/json";
/// MIME type of binary response bodies.
const OCTET_MIME: &str = "application/octet-stream";
/// MIME type of the HTML info page.
const HTML_MIME: &str = "text/html; charset=utf-8";
271
/// Extension trait for adding the headers that every Surver response carries.
trait DefaultHeader {
    /// Returns `self` with the default headers added.
    fn default_header(self) -> Self;
}
275
276impl DefaultHeader for hyper::http::response::Builder {
277    fn default_header(self) -> Self {
278        self.header(HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER)
279            .header(X_WELLEN_VERSION, WELLEN_VERSION)
280            .header(X_SURFER_VERSION, SURFER_VERSION)
281            .header("Cache-Control", "no-cache")
282    }
283}
284
285fn build_response(
286    status: StatusCode,
287    content_type: &str,
288    body: Vec<u8>,
289) -> Result<Response<Full<Bytes>>> {
290    Ok(Response::builder()
291        .status(status)
292        .header(CONTENT_TYPE, content_type)
293        .default_header()
294        .body(Full::from(body))?)
295}
296
297fn not_found_response(message: &[u8]) -> Result<Response<Full<Bytes>>> {
298    build_response(StatusCode::NOT_FOUND, OCTET_MIME, message.to_vec())
299}
300
301async fn handle_cmd(
302    state: &Arc<RwLock<SurverState>>,
303    txs: &[Sender<LoaderMessage>],
304    cmd: &str,
305    file_index: Option<usize>,
306    args: &[&str],
307) -> Result<Response<Full<Bytes>>> {
308    // Check file index is valid if provided
309    if let Some(file_index) = file_index {
310        let state_guard = state.read().expect("State lock poisoned in handle_cmd");
311        if file_index >= state_guard.file_infos.len() {
312            drop(state_guard);
313            return not_found_response(b"Invalid file index");
314        }
315    }
316    match (file_index, cmd, args) {
317        (_, "get_status", []) => {
318            let body = get_status(state)?;
319            build_response(StatusCode::OK, JSON_MIME, body)
320        }
321        (Some(file_index), "get_hierarchy", []) => {
322            let body = get_hierarchy(state, file_index)?;
323            build_response(StatusCode::OK, OCTET_MIME, body)
324        }
325        (Some(file_index), "get_time_table", []) => {
326            let body = get_timetable(state, file_index).await?;
327            build_response(StatusCode::OK, OCTET_MIME, body)
328        }
329        (Some(file_index), "get_signals", id_strings) => {
330            let body = get_signals(state, file_index, txs, id_strings).await?;
331            build_response(StatusCode::OK, OCTET_MIME, body)
332        }
333        (Some(file_index), "reload", []) => {
334            let mut state_guard = state.write().expect("State lock poisoned in reload");
335            let file_info = &mut state_guard.file_infos[file_index];
336            // Check file existence, size, and mtime
337            let Ok(meta) = fs::metadata(file_info.filename.clone()) else {
338                drop(state_guard);
339                return not_found_response(b"error: file not found");
340            };
341            let mtime = meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH);
342            // Should probably look at file lengths as well for extra safety, but they are not updated correctly at the moment
343            let unchanged =
344                file_info.last_modification_time == Some(mtime) && file_info.last_reload_ok;
345            if unchanged {
346                drop(state_guard);
347                return build_response(
348                    StatusCode::NOT_MODIFIED,
349                    JSON_MIME,
350                    b"info: file unchanged".to_vec(),
351                );
352            }
353            file_info.last_modification_time = Some(mtime);
354            info!(
355                "File modification time updated to {}",
356                file_info.modification_time_string()
357            );
358            file_info.reloading = true;
359            file_info.last_reload_ok = false;
360            drop(state_guard);
361            info!("Reload requested");
362            txs[file_index].send(LoaderMessage::Reload)?;
363            let body = get_status(state)?;
364            build_response(StatusCode::ACCEPTED, JSON_MIME, body)
365        }
366        _ => {
367            // unknown command or unexpected number of arguments
368            not_found_response(&[])
369        }
370    }
371}
372
373async fn handle(
374    state: Arc<RwLock<SurverState>>,
375    shared: Arc<ReadOnly>,
376    txs: Vec<Sender<LoaderMessage>>,
377    req: Request<hyper::body::Incoming>,
378) -> Result<Response<Full<Bytes>>> {
379    // Check if favicon is requested
380    if req.uri().path() == "/favicon.ico" {
381        let favicon_data = include_bytes!("../assets/favicon.ico");
382        return Ok(Response::builder()
383            .status(StatusCode::OK)
384            .header("Content-Type", "image/x-icon")
385            .header("Cache-Control", "public, max-age=604800")
386            .body(Full::from(&favicon_data[..]))?);
387    }
388    // check to see if the correct token was received
389    let path_parts = req.uri().path().split('/').skip(1).collect::<Vec<_>>();
390
391    // check token
392    if let Some(provided_token) = path_parts.first() {
393        if *provided_token != shared.token {
394            warn!(
395                "Received request with invalid token: {provided_token} != {}\n{:?}",
396                shared.token,
397                req.uri()
398            );
399            return not_found_response(&[]);
400        }
401    } else {
402        // no token
403        warn!("Received request with no token: {:?}", req.uri());
404        return not_found_response(&[]);
405    }
406
407    // Try to parse file index from path_parts[1]
408    let (file_index, cmd_idx) = path_parts
409        .get(1)
410        .and_then(|s| s.parse::<usize>().ok())
411        .map_or((None, 1), |idx| (Some(idx), 2));
412    // check command
413    let response = if let Some(cmd) = path_parts.get(cmd_idx) {
414        handle_cmd(&state, &txs, cmd, file_index, &path_parts[cmd_idx + 1..]).await?
415    } else {
416        // valid token, but no command => return info
417        let body = Full::from(get_info_page(&shared, &state));
418        Response::builder()
419            .status(StatusCode::OK)
420            .header(CONTENT_TYPE, HTML_MIME)
421            .default_header()
422            .body(body)?
423    };
424
425    Ok(response)
426}
427
/// Minimum accepted token length, compared against `String::len` (bytes).
const MIN_TOKEN_LEN: usize = 8;
/// Number of alphanumeric characters in an automatically generated token.
const RAND_TOKEN_LEN: usize = 24;

/// Flag that `surver_main` sets to `true` once the listening socket is bound.
pub type ServerStartedFlag = Arc<std::sync::atomic::AtomicBool>;
432
433pub async fn surver_main(
434    port: u16,
435    bind_address: String,
436    token: Option<String>,
437    filenames: &[String],
438    started: Option<ServerStartedFlag>,
439) -> Result<()> {
440    // if no token was provided, we generate one
441    let token = token.unwrap_or_else(|| {
442        // generate a random ASCII token
443        repeat_with(fastrand::alphanumeric)
444            .take(RAND_TOKEN_LEN)
445            .collect()
446    });
447
448    if token.len() < MIN_TOKEN_LEN {
449        bail!("Token `{token}` is too short. At least {MIN_TOKEN_LEN} characters are required!");
450    }
451
452    let state = Arc::new(RwLock::new(SurverState { file_infos: vec![] }));
453
454    let mut txs: Vec<Sender<LoaderMessage>> = Vec::new();
455    // load files
456    for (file_index, filename) in filenames.iter().enumerate() {
457        let start_read_header = web_time::Instant::now();
458        let header_result = wellen::viewers::read_header_from_file(
459            filename.clone(),
460            &WELLEN_SURFER_DEFAULT_OPTIONS,
461        )
462        .map_err(|e| anyhow!("{e:?}"))
463        .with_context(|| format!("Failed to parse wave file: {filename}"))?;
464        info!(
465            "Loaded header of {filename} in {:?}",
466            start_read_header.elapsed()
467        );
468
469        let file_info = FileInfo {
470            filename: filename.clone(),
471            hierarchy: header_result.hierarchy,
472            file_format: header_result.file_format,
473            header_len: 0, // FIXME: get value from wellen
474            body_len: header_result.body_len,
475            body_progress: Arc::new(AtomicU64::new(0)),
476            notify: Arc::new(Notify::new()),
477            timetable: vec![],
478            signals: HashMap::new(),
479            reloading: false,
480            last_reload_ok: true,
481            last_reload_time: None,
482            last_modification_time: None,
483        };
484        {
485            let mut state_guard = state.write().expect("State lock poisoned when adding file");
486            state_guard.file_infos.push(file_info);
487        }
488        // channel to communicate with loader
489        let (tx, rx) = std::sync::mpsc::channel::<LoaderMessage>();
490        txs.push(tx.clone());
491        // start work thread
492        let state_2 = state.clone();
493        std::thread::spawn(move || loader(&state_2, header_result.body, file_index, &rx));
494    }
495    let ip_addr: std::net::IpAddr = bind_address
496        .parse()
497        .with_context(|| format!("Invalid bind address: {bind_address}"))?;
498    let use_localhost = ip_addr.is_loopback();
499    if !use_localhost {
500        warn!(
501            "Server is binding to {bind_address} instead of 127.0.0.1/0:0:0:0:0:0:0:1 (localhost)"
502        );
503        warn!("This may make the server accessible from external networks");
504        warn!("Surver traffic is unencrypted and unauthenticated - use with caution!");
505    }
506
507    // immutable read-only data
508    let addr = SocketAddr::new(ip_addr, port);
509    let url = format!("http://{addr}/{token}");
510    let url_copy = url.clone();
511    let token_copy = token.clone();
512    let shared = Arc::new(ReadOnly { url, token });
513
514    // print out status
515    info!("Starting server on {addr}. To use:");
516    info!("1. Setup an ssh tunnel: -L {port}:localhost:{port}");
517    let hostname = whoami::hostname();
518    if let Ok(hostname) = hostname.as_ref()
519        && hostname != "localhost"
520        && let Ok(username) = whoami::username()
521    {
522        info!(
523            "   The correct command may be: ssh -L {port}:localhost:{port} {username}@{hostname} "
524        );
525    }
526
527    info!("2. Start Surfer: surfer {url_copy} ");
528    if !use_localhost && let Ok(hostname) = hostname {
529        let hosturl = format!("http://{hostname}:{port}/{token_copy}");
530        info!("or, if the host is directly accessible:");
531        info!("1. Start Surfer: surfer {hosturl} ");
532    }
533    // create listener and serve it
534    let listener = TcpListener::bind(&addr).await?;
535
536    // we have started the server
537    if let Some(started) = started {
538        started.store(true, Ordering::SeqCst);
539    }
540
541    // main server loop
542    loop {
543        let (stream, _) = listener.accept().await?;
544        let io = TokioIo::new(stream);
545
546        let state = state.clone();
547        let shared = shared.clone();
548        let txs = txs.clone();
549        tokio::task::spawn(async move {
550            let service =
551                service_fn(move |req| handle(state.clone(), shared.clone(), txs.clone(), req));
552            if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
553                error!("server error: {e}");
554            }
555        });
556    }
557}
558
559/// Thread that loads the body and signals.
560fn loader(
561    state: &Arc<RwLock<SurverState>>,
562    mut body_cont: viewers::ReadBodyContinuation<std::io::BufReader<std::fs::File>>,
563    file_index: usize,
564    rx: &std::sync::mpsc::Receiver<LoaderMessage>,
565) -> Result<()> {
566    loop {
567        // load the body of the file
568        let start_load_body = web_time::Instant::now();
569        let state_guard = state
570            .read()
571            .expect("State lock poisoned in loader before body load");
572        let file_info = &state_guard.file_infos[file_index];
573        let filename = file_info.filename.clone();
574        let body_result = viewers::read_body(
575            body_cont,
576            &file_info.hierarchy,
577            Some(file_info.body_progress.clone()),
578        )
579        .map_err(|e| anyhow!("{e:?}"))
580        .with_context(|| format!("Failed to parse body of wave file: {filename}"))?;
581        drop(state_guard);
582        info!(
583            "Loaded body of {} in {:?}",
584            filename,
585            start_load_body.elapsed()
586        );
587
588        // update state with body results
589        {
590            let mut state_guard = state
591                .write()
592                .expect("State lock poisoned in loader after body load");
593            let file_info = &mut state_guard.file_infos[file_index];
594            file_info.timetable = body_result.time_table;
595            file_info.signals.clear(); // Clear old signals on reload
596            if let Ok(meta) = fs::metadata(&file_info.filename) {
597                file_info.last_modification_time = Some(meta.modified()?);
598                info!(
599                    "File modification time of {} set to {}",
600                    filename,
601                    file_info.modification_time_string()
602                );
603            }
604            file_info.last_reload_time = Some(Instant::now());
605            file_info.reloading = false;
606            file_info.last_reload_ok = true;
607            file_info.notify.notify_waiters();
608        }
609        // source is private, only owned by us
610        let mut source = body_result.source;
611
612        // process requests for signals to be loaded
613        loop {
614            let msg = rx.recv()?;
615
616            match msg {
617                LoaderMessage::SignalRequest(ids) => {
618                    // make sure that we do not load signals that have already been loaded
619                    let mut filtered_ids = {
620                        let state_guard = state
621                            .read()
622                            .expect("State lock poisoned in loader signal request");
623                        ids.iter()
624                            .filter(|id| {
625                                !state_guard.file_infos[file_index].signals.contains_key(id)
626                            })
627                            .copied()
628                            .collect::<Vec<_>>()
629                    };
630
631                    // check if there is anything left to do
632                    if filtered_ids.is_empty() {
633                        continue;
634                    }
635
636                    // load signals without holding the lock
637                    filtered_ids.sort();
638                    filtered_ids.dedup();
639                    let result = {
640                        let state_guard = state
641                            .read()
642                            .expect("State lock poisoned in loader signal request");
643                        source.load_signals(
644                            &filtered_ids,
645                            &state_guard.file_infos[file_index].hierarchy,
646                            true,
647                        )
648                    };
649
650                    // store signals
651                    {
652                        let mut state_guard = state
653                            .write()
654                            .expect("State lock poisoned in loader when storing signals");
655                        for (id, signal) in result {
656                            state_guard.file_infos[file_index]
657                                .signals
658                                .insert(id, signal);
659                        }
660                        state_guard.file_infos[file_index].notify.notify_waiters();
661                    }
662                }
663                LoaderMessage::Reload => {
664                    let state_guard = state
665                        .read()
666                        .expect("State lock poisoned in loader before reload");
667                    info!(
668                        "Reloading waveform file: {}",
669                        state_guard.file_infos[file_index].filename
670                    );
671                    // Reset progress counter
672                    state_guard.file_infos[file_index]
673                        .body_progress
674                        .store(0, Ordering::SeqCst);
675
676                    // Re-read header to get new body continuation
677                    let header_result = wellen::viewers::read_header_from_file(
678                        state_guard.file_infos[file_index].filename.clone(),
679                        &WELLEN_SURFER_DEFAULT_OPTIONS,
680                    )
681                    .map_err(|e| anyhow!("{e:?}"))
682                    .with_context(|| {
683                        format!(
684                            "Failed to reload wave file: {}",
685                            state_guard.file_infos[file_index].filename
686                        )
687                    })?;
688
689                    body_cont = header_result.body;
690                    break; // Break inner loop to reload the body
691                }
692            }
693        }
694    }
695}