surver/
server.rs

1//! Handling of external communication in Surver.
2use bincode::Options;
3use color_eyre::eyre::{anyhow, bail, Context};
4use color_eyre::Result;
5use http_body_util::Full;
6use hyper::body::Bytes;
7use hyper::server::conn::http1;
8use hyper::service::service_fn;
9use hyper::{Request, Response, StatusCode};
10use hyper_util::rt::TokioIo;
11use log::{error, info, warn};
12use std::collections::HashMap;
13use std::io::{BufRead, Seek};
14use std::iter::repeat_with;
15use std::net::SocketAddr;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::mpsc::Sender;
18use std::sync::{Arc, RwLock};
19use tokio::net::TcpListener;
20use wellen::{
21    viewers, CompressedSignal, CompressedTimeTable, FileFormat, Hierarchy, Signal, SignalRef, Time,
22};
23
24use crate::{
25    Status, BINCODE_OPTIONS, HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER, SURFER_VERSION,
26    WELLEN_SURFER_DEFAULT_OPTIONS, WELLEN_VERSION, X_SURFER_VERSION, X_WELLEN_VERSION,
27};
28
29struct ReadOnly {
30    url: String,
31    token: String,
32    filename: String,
33    hierarchy: Hierarchy,
34    file_format: FileFormat,
35    header_len: u64,
36    body_len: u64,
37    body_progress: Arc<AtomicU64>,
38}
39
40#[derive(Default)]
41struct State {
42    timetable: Vec<Time>,
43    signals: HashMap<SignalRef, Signal>,
44}
45
46type SignalRequest = Vec<SignalRef>;
47
48fn get_info_page(shared: Arc<ReadOnly>) -> String {
49    let bytes_loaded = shared.body_progress.load(Ordering::SeqCst);
50
51    let progress = if bytes_loaded == shared.body_len {
52        format!(
53            "{} loaded",
54            bytesize::ByteSize::b(shared.body_len + shared.header_len)
55        )
56    } else {
57        format!(
58            "{} / {}",
59            bytesize::ByteSize::b(bytes_loaded + shared.header_len),
60            bytesize::ByteSize::b(shared.body_len + shared.header_len)
61        )
62    };
63
64    format!(
65        r#"
66    <!DOCTYPE html><html lang="en">
67    <head><title>Surver - Surfer Remote Server</title></head><body>
68    <h1>Surver - Surfer Remote Server</h1>
69    <b>To connect, run:</b> <code>surfer {}</code><br>
70    <b>Wellen version:</b> {WELLEN_VERSION}<br>
71    <b>Surfer version:</b> {SURFER_VERSION}<br>
72    <b>Filename:</b> {}<br>
73    <b>Progress:</b> {progress}<br>
74    </body></html>
75    "#,
76        shared.url, shared.filename
77    )
78}
79
80fn get_hierarchy(shared: Arc<ReadOnly>) -> Result<Vec<u8>> {
81    let mut raw = BINCODE_OPTIONS.serialize(&shared.file_format)?;
82    let mut raw2 = BINCODE_OPTIONS.serialize(&shared.hierarchy)?;
83    raw.append(&mut raw2);
84    let compressed = lz4_flex::compress_prepend_size(&raw);
85    info!(
86        "Sending hierarchy. {} raw, {} compressed.",
87        bytesize::ByteSize::b(raw.len() as u64),
88        bytesize::ByteSize::b(compressed.len() as u64)
89    );
90    Ok(compressed)
91}
92
93async fn get_timetable(state: Arc<RwLock<State>>) -> Result<Vec<u8>> {
94    // poll to see when the time table is available
95    #[allow(unused_assignments)]
96    let mut table = vec![];
97    loop {
98        {
99            let state = state.read().unwrap();
100            if !state.timetable.is_empty() {
101                table = state.timetable.clone();
102                break;
103            }
104        }
105        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
106    }
107    let raw_size = table.len() * std::mem::size_of::<Time>();
108    let compressed = BINCODE_OPTIONS.serialize(&CompressedTimeTable::compress(&table))?;
109    info!(
110        "Sending timetable. {} raw, {} compressed.",
111        bytesize::ByteSize::b(raw_size as u64),
112        bytesize::ByteSize::b(compressed.len() as u64)
113    );
114    Ok(compressed)
115}
116
117fn get_status(shared: Arc<ReadOnly>) -> Result<Vec<u8>> {
118    let status = Status {
119        bytes: shared.body_len + shared.header_len,
120        bytes_loaded: shared.body_progress.load(Ordering::SeqCst) + shared.header_len,
121        filename: shared.filename.clone(),
122        wellen_version: WELLEN_VERSION.to_string(),
123        surfer_version: SURFER_VERSION.to_string(),
124        file_format: shared.file_format,
125    };
126    Ok(serde_json::to_vec(&status)?)
127}
128
129async fn get_signals(
130    state: Arc<RwLock<State>>,
131    tx: Sender<SignalRequest>,
132    id_strings: &[&str],
133) -> Result<Vec<u8>> {
134    let mut ids = Vec::with_capacity(id_strings.len());
135    for id in id_strings.iter() {
136        ids.push(SignalRef::from_index(id.parse::<u64>()? as usize).unwrap());
137    }
138
139    if ids.is_empty() {
140        return Ok(vec![]);
141    }
142    let num_ids = ids.len();
143
144    // send request to background thread
145    tx.send(ids.clone())?;
146
147    // poll to see when all our ids are returned
148    let mut data = vec![];
149    leb128::write::unsigned(&mut data, num_ids as u64)?;
150    let mut raw_size = 0;
151    loop {
152        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
153        {
154            let state = state.read().unwrap();
155            if ids.iter().all(|id| state.signals.contains_key(id)) {
156                for id in ids {
157                    let signal = &state.signals[&id];
158                    raw_size += BINCODE_OPTIONS.serialize(signal)?.len();
159                    let comp = CompressedSignal::compress(signal);
160                    data.append(&mut BINCODE_OPTIONS.serialize(&comp)?);
161                }
162                break;
163            }
164        };
165    }
166    info!(
167        "Sending {} signals. {} raw, {} compressed.",
168        num_ids,
169        bytesize::ByteSize::b(raw_size as u64),
170        bytesize::ByteSize::b(data.len() as u64)
171    );
172    Ok(data)
173}
174
/// Media type sent with the JSON status response.
const JSON_MIME: &str = "application/json";
/// Name of the header that carries a response's media type.
const CONTENT_TYPE: &str = "Content-Type";
177
178trait DefaultHeader {
179    fn default_header(self) -> Self;
180}
181
182impl DefaultHeader for hyper::http::response::Builder {
183    fn default_header(self) -> Self {
184        self.header(HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER)
185            .header(X_WELLEN_VERSION, WELLEN_VERSION)
186            .header(X_SURFER_VERSION, SURFER_VERSION)
187            .header("Cache-Control", "no-cache")
188    }
189}
190
191async fn handle_cmd(
192    state: Arc<RwLock<State>>,
193    shared: Arc<ReadOnly>,
194    tx: Sender<SignalRequest>,
195    cmd: &str,
196    args: &[&str],
197) -> Result<Response<Full<Bytes>>> {
198    let response = match (cmd, args) {
199        ("get_status", []) => {
200            let body = get_status(shared)?;
201            Response::builder()
202                .status(StatusCode::OK)
203                .header(CONTENT_TYPE, JSON_MIME)
204                .default_header()
205                .body(Full::from(body))
206        }
207        ("get_hierarchy", []) => {
208            let body = get_hierarchy(shared)?;
209            Response::builder()
210                .status(StatusCode::OK)
211                .default_header()
212                .body(Full::from(body))
213        }
214        ("get_time_table", []) => {
215            let body = get_timetable(state).await?;
216            Response::builder()
217                .status(StatusCode::OK)
218                .default_header()
219                .body(Full::from(body))
220        }
221        ("get_signals", id_strings) => {
222            let body = get_signals(state, tx, id_strings).await?;
223            Response::builder()
224                .status(StatusCode::OK)
225                .default_header()
226                .body(Full::from(body))
227        }
228        _ => {
229            // unknown command or unexpected number of arguments
230            Response::builder()
231                .status(StatusCode::NOT_FOUND)
232                .body(Full::from(vec![]))
233        }
234    };
235    Ok(response?)
236}
237
238async fn handle(
239    state: Arc<RwLock<State>>,
240    shared: Arc<ReadOnly>,
241    tx: Sender<SignalRequest>,
242    req: Request<hyper::body::Incoming>,
243) -> Result<Response<Full<Bytes>>> {
244    // check to see if the correct token was received
245    let path_parts = req.uri().path().split('/').skip(1).collect::<Vec<_>>();
246
247    // check token
248    if let Some(provided_token) = path_parts.first() {
249        if *provided_token != shared.token {
250            warn!(
251                "Received request with invalid token: {provided_token} != {}\n{:?}",
252                shared.token,
253                req.uri()
254            );
255            return Ok(Response::builder()
256                .status(StatusCode::NOT_FOUND)
257                .body(Full::from(vec![]))?);
258        }
259    } else {
260        // no token
261        warn!("Received request with no token: {:?}", req.uri());
262        return Ok(Response::builder()
263            .status(StatusCode::NOT_FOUND)
264            .body(Full::from(vec![]))?);
265    }
266
267    // check command
268    let response = if let Some(cmd) = path_parts.get(1) {
269        handle_cmd(state, shared, tx, cmd, &path_parts[2..]).await?
270    } else {
271        // valid token, but no command => return info
272        let body = Full::from(get_info_page(shared));
273        Response::builder()
274            .status(StatusCode::OK)
275            .default_header()
276            .body(body)?
277    };
278
279    Ok(response)
280}
281
/// Shortest token (in bytes) accepted from the command line.
const MIN_TOKEN_LEN: usize = 8;
/// Number of characters in the token generated when none is supplied.
const RAND_TOKEN_LEN: usize = 24;

/// Flag that `server_main` sets to `true` once its listener is bound.
pub type ServerStartedFlag = std::sync::Arc<std::sync::atomic::AtomicBool>;
286
287pub async fn server_main(
288    port: u16,
289    token: Option<String>,
290    filename: String,
291    started: Option<ServerStartedFlag>,
292) -> Result<()> {
293    // if no token was provided, we generate one
294    let token = token.unwrap_or_else(|| {
295        // generate a random ASCII token
296        repeat_with(fastrand::alphanumeric)
297            .take(RAND_TOKEN_LEN)
298            .collect()
299    });
300
301    if token.len() < MIN_TOKEN_LEN {
302        bail!("Token `{token}` is too short. At least {MIN_TOKEN_LEN} characters are required!");
303    }
304
305    // load file
306    let start_read_header = web_time::Instant::now();
307    let header_result =
308        wellen::viewers::read_header_from_file(filename.clone(), &WELLEN_SURFER_DEFAULT_OPTIONS)
309            .map_err(|e| anyhow!("{e:?}"))
310            .with_context(|| format!("Failed to parse wave file: {filename}"))?;
311    info!(
312        "Loaded header of {filename} in {:?}",
313        start_read_header.elapsed()
314    );
315    let addr = SocketAddr::from(([127, 0, 0, 1], port));
316
317    // immutable read-only data
318    let url = format!("http://{addr:?}/{token}");
319    let url_copy = url.clone();
320    let token_copy = token.clone();
321    let shared = Arc::new(ReadOnly {
322        url,
323        token,
324        filename,
325        hierarchy: header_result.hierarchy,
326        file_format: header_result.file_format,
327        header_len: 0, // FIXME: get value from wellen
328        body_len: header_result.body_len,
329        body_progress: Arc::new(AtomicU64::new(0)),
330    });
331    // state can be written by the loading thread
332    let state = Arc::new(RwLock::new(State::default()));
333    // channel to communicate with loader
334    let (tx, rx) = std::sync::mpsc::channel::<SignalRequest>();
335    // start work thread
336    let shared_2 = shared.clone();
337    let state_2 = state.clone();
338    std::thread::spawn(move || loader(shared_2, header_result.body, state_2, rx));
339
340    // print out status
341    info!("Starting server on {addr:?}. To use:");
342    info!("1. Setup an ssh tunnel: -L {port}:localhost:{port}");
343    let hostname = whoami::fallible::hostname();
344    if let Ok(hostname) = hostname.as_ref() {
345        let username = whoami::username();
346        info!(
347            "   The correct command may be: ssh -L {port}:localhost:{port} {username}@{hostname} "
348        );
349    }
350
351    info!("2. Start Surfer: surfer {url_copy} ");
352    if let Ok(hostname) = hostname {
353        let hosturl = format!("http://{hostname}:{port}/{token_copy}");
354        info!("or, if the host is directly accessible:");
355        info!("1. Start Surfer: surfer {hosturl} ");
356    }
357    // create listener and serve it
358    let listener = TcpListener::bind(&addr).await?;
359
360    // we have started the server
361    if let Some(started) = started {
362        started.store(true, Ordering::SeqCst);
363    }
364
365    // main server loop
366    loop {
367        let (stream, _) = listener.accept().await?;
368        let io = TokioIo::new(stream);
369
370        let state = state.clone();
371        let shared = shared.clone();
372        let tx = tx.clone();
373        tokio::task::spawn(async move {
374            let service =
375                service_fn(move |req| handle(state.clone(), shared.clone(), tx.clone(), req));
376            if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
377                error!("server error: {}", e);
378            }
379        });
380    }
381}
382
383/// Thread that loads the body and signals.
384fn loader<R: BufRead + Seek + Sync + Send + 'static>(
385    shared: Arc<ReadOnly>,
386    body_cont: viewers::ReadBodyContinuation<R>,
387    state: Arc<RwLock<State>>,
388    rx: std::sync::mpsc::Receiver<SignalRequest>,
389) -> Result<()> {
390    // load the body of the file
391    let start_load_body = web_time::Instant::now();
392    let body_result = viewers::read_body(
393        body_cont,
394        &shared.hierarchy,
395        Some(shared.body_progress.clone()),
396    )
397    .map_err(|e| anyhow!("{e:?}"))
398    .with_context(|| format!("Failed to parse body of wave file: {}", shared.filename))?;
399    info!("Loaded body in {:?}", start_load_body.elapsed());
400
401    // update state with body results
402    {
403        let mut state = state.write().unwrap();
404        state.timetable = body_result.time_table;
405    }
406    // source is private, only owned by us
407    let mut source = body_result.source;
408
409    // process requests for signals to be loaded
410    loop {
411        let ids = rx.recv()?;
412
413        // make sure that we do not load signals that have already been loaded
414        let mut filtered_ids = {
415            let state_lock = state.read().unwrap();
416            ids.iter()
417                .filter(|id| !state_lock.signals.contains_key(id))
418                .cloned()
419                .collect::<Vec<_>>()
420        };
421
422        // check if there is anything left to do
423        if filtered_ids.is_empty() {
424            continue;
425        }
426
427        // load signals without holding the lock
428        filtered_ids.sort();
429        filtered_ids.dedup();
430        let result = source.load_signals(&filtered_ids, &shared.hierarchy, true);
431
432        // store signals
433        {
434            let mut state = state.write().unwrap();
435            for (id, signal) in result {
436                state.signals.insert(id, signal);
437            }
438        }
439    }
440
441    // unreachable!("the user needs to terminate the server")
442}