surver/
server.rs

1//! Handling of external communication in Surver.
2use bincode::Options;
3use eyre::{anyhow, bail, Context, Result};
4use http_body_util::Full;
5use hyper::body::Bytes;
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response, StatusCode};
9use hyper_util::rt::TokioIo;
10use log::{error, info, warn};
11use std::collections::HashMap;
12use std::io::{BufRead, Seek};
13use std::iter::repeat_with;
14use std::net::SocketAddr;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::mpsc::Sender;
17use std::sync::{Arc, RwLock};
18use tokio::net::TcpListener;
19use wellen::{
20    viewers, CompressedSignal, CompressedTimeTable, FileFormat, Hierarchy, Signal, SignalRef, Time,
21};
22
23use crate::{
24    Status, BINCODE_OPTIONS, HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER, SURFER_VERSION,
25    WELLEN_SURFER_DEFAULT_OPTIONS, WELLEN_VERSION, X_SURFER_VERSION, X_WELLEN_VERSION,
26};
27
28struct ReadOnly {
29    url: String,
30    token: String,
31    filename: String,
32    hierarchy: Hierarchy,
33    file_format: FileFormat,
34    header_len: u64,
35    body_len: u64,
36    body_progress: Arc<AtomicU64>,
37}
38
39#[derive(Default)]
40struct State {
41    timetable: Vec<Time>,
42    signals: HashMap<SignalRef, Signal>,
43}
44
45type SignalRequest = Vec<SignalRef>;
46
47fn get_info_page(shared: Arc<ReadOnly>) -> String {
48    let bytes_loaded = shared.body_progress.load(Ordering::SeqCst);
49
50    let progress = if bytes_loaded == shared.body_len {
51        format!(
52            "{} loaded",
53            bytesize::ByteSize::b(shared.body_len + shared.header_len)
54        )
55    } else {
56        format!(
57            "{} / {}",
58            bytesize::ByteSize::b(bytes_loaded + shared.header_len),
59            bytesize::ByteSize::b(shared.body_len + shared.header_len)
60        )
61    };
62
63    format!(
64        r#"
65    <!DOCTYPE html><html lang="en">
66    <head><title>Surver - Surfer Remote Server</title></head><body>
67    <h1>Surver - Surfer Remote Server</h1>
68    <b>To connect, run:</b> <code>surfer {}</code><br>
69    <b>Wellen version:</b> {WELLEN_VERSION}<br>
70    <b>Surfer version:</b> {SURFER_VERSION}<br>
71    <b>Filename:</b> {}<br>
72    <b>Progress:</b> {progress}<br>
73    </body></html>
74    "#,
75        shared.url, shared.filename
76    )
77}
78
79fn get_hierarchy(shared: Arc<ReadOnly>) -> Result<Vec<u8>> {
80    let mut raw = BINCODE_OPTIONS.serialize(&shared.file_format)?;
81    let mut raw2 = BINCODE_OPTIONS.serialize(&shared.hierarchy)?;
82    raw.append(&mut raw2);
83    let compressed = lz4_flex::compress_prepend_size(&raw);
84    info!(
85        "Sending hierarchy. {} raw, {} compressed.",
86        bytesize::ByteSize::b(raw.len() as u64),
87        bytesize::ByteSize::b(compressed.len() as u64)
88    );
89    Ok(compressed)
90}
91
92async fn get_timetable(state: Arc<RwLock<State>>) -> Result<Vec<u8>> {
93    // poll to see when the time table is available
94    #[allow(unused_assignments)]
95    let mut table = vec![];
96    loop {
97        {
98            let state = state.read().unwrap();
99            if !state.timetable.is_empty() {
100                table = state.timetable.clone();
101                break;
102            }
103        }
104        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
105    }
106    let raw_size = table.len() * std::mem::size_of::<Time>();
107    let compressed = BINCODE_OPTIONS.serialize(&CompressedTimeTable::compress(&table))?;
108    info!(
109        "Sending timetable. {} raw, {} compressed.",
110        bytesize::ByteSize::b(raw_size as u64),
111        bytesize::ByteSize::b(compressed.len() as u64)
112    );
113    Ok(compressed)
114}
115
116fn get_status(shared: Arc<ReadOnly>) -> Result<Vec<u8>> {
117    let status = Status {
118        bytes: shared.body_len + shared.header_len,
119        bytes_loaded: shared.body_progress.load(Ordering::SeqCst) + shared.header_len,
120        filename: shared.filename.clone(),
121        wellen_version: WELLEN_VERSION.to_string(),
122        surfer_version: SURFER_VERSION.to_string(),
123        file_format: shared.file_format,
124    };
125    Ok(serde_json::to_vec(&status)?)
126}
127
128async fn get_signals(
129    state: Arc<RwLock<State>>,
130    tx: Sender<SignalRequest>,
131    id_strings: &[&str],
132) -> Result<Vec<u8>> {
133    let mut ids = Vec::with_capacity(id_strings.len());
134    for id in id_strings.iter() {
135        ids.push(SignalRef::from_index(id.parse::<u64>()? as usize).unwrap());
136    }
137
138    if ids.is_empty() {
139        return Ok(vec![]);
140    }
141    let num_ids = ids.len();
142
143    // send request to background thread
144    tx.send(ids.clone())?;
145
146    // poll to see when all our ids are returned
147    let mut data = vec![];
148    leb128::write::unsigned(&mut data, num_ids as u64)?;
149    let mut raw_size = 0;
150    loop {
151        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
152        {
153            let state = state.read().unwrap();
154            if ids.iter().all(|id| state.signals.contains_key(id)) {
155                for id in ids {
156                    let signal = &state.signals[&id];
157                    raw_size += BINCODE_OPTIONS.serialize(signal)?.len();
158                    let comp = CompressedSignal::compress(signal);
159                    data.append(&mut BINCODE_OPTIONS.serialize(&comp)?);
160                }
161                break;
162            }
163        };
164    }
165    info!(
166        "Sending {} signals. {} raw, {} compressed.",
167        num_ids,
168        bytesize::ByteSize::b(raw_size as u64),
169        bytesize::ByteSize::b(data.len() as u64)
170    );
171    Ok(data)
172}
173
/// MIME type attached to JSON responses.
const JSON_MIME: &str = "application/json";
/// Name of the HTTP header that carries a response's MIME type.
const CONTENT_TYPE: &str = "Content-Type";
176
177trait DefaultHeader {
178    fn default_header(self) -> Self;
179}
180
181impl DefaultHeader for hyper::http::response::Builder {
182    fn default_header(self) -> Self {
183        self.header(HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER)
184            .header(X_WELLEN_VERSION, WELLEN_VERSION)
185            .header(X_SURFER_VERSION, SURFER_VERSION)
186            .header("Cache-Control", "no-cache")
187    }
188}
189
190async fn handle_cmd(
191    state: Arc<RwLock<State>>,
192    shared: Arc<ReadOnly>,
193    tx: Sender<SignalRequest>,
194    cmd: &str,
195    args: &[&str],
196) -> Result<Response<Full<Bytes>>> {
197    let response = match (cmd, args) {
198        ("get_status", []) => {
199            let body = get_status(shared)?;
200            Response::builder()
201                .status(StatusCode::OK)
202                .header(CONTENT_TYPE, JSON_MIME)
203                .default_header()
204                .body(Full::from(body))
205        }
206        ("get_hierarchy", []) => {
207            let body = get_hierarchy(shared)?;
208            Response::builder()
209                .status(StatusCode::OK)
210                .default_header()
211                .body(Full::from(body))
212        }
213        ("get_time_table", []) => {
214            let body = get_timetable(state).await?;
215            Response::builder()
216                .status(StatusCode::OK)
217                .default_header()
218                .body(Full::from(body))
219        }
220        ("get_signals", id_strings) => {
221            let body = get_signals(state, tx, id_strings).await?;
222            Response::builder()
223                .status(StatusCode::OK)
224                .default_header()
225                .body(Full::from(body))
226        }
227        _ => {
228            // unknown command or unexpected number of arguments
229            Response::builder()
230                .status(StatusCode::NOT_FOUND)
231                .body(Full::from(vec![]))
232        }
233    };
234    Ok(response?)
235}
236
237async fn handle(
238    state: Arc<RwLock<State>>,
239    shared: Arc<ReadOnly>,
240    tx: Sender<SignalRequest>,
241    req: Request<hyper::body::Incoming>,
242) -> Result<Response<Full<Bytes>>> {
243    // check to see if the correct token was received
244    let path_parts = req.uri().path().split('/').skip(1).collect::<Vec<_>>();
245
246    // check token
247    if let Some(provided_token) = path_parts.first() {
248        if *provided_token != shared.token {
249            warn!(
250                "Received request with invalid token: {provided_token} != {}\n{:?}",
251                shared.token,
252                req.uri()
253            );
254            return Ok(Response::builder()
255                .status(StatusCode::NOT_FOUND)
256                .body(Full::from(vec![]))?);
257        }
258    } else {
259        // no token
260        warn!("Received request with no token: {:?}", req.uri());
261        return Ok(Response::builder()
262            .status(StatusCode::NOT_FOUND)
263            .body(Full::from(vec![]))?);
264    }
265
266    // check command
267    let response = if let Some(cmd) = path_parts.get(1) {
268        handle_cmd(state, shared, tx, cmd, &path_parts[2..]).await?
269    } else {
270        // valid token, but no command => return info
271        let body = Full::from(get_info_page(shared));
272        Response::builder()
273            .status(StatusCode::OK)
274            .default_header()
275            .body(body)?
276    };
277
278    Ok(response)
279}
280
/// Shortest token (in bytes) that the server accepts.
const MIN_TOKEN_LEN: usize = 8;
/// Length of the random alphanumeric token generated when none is given.
const RAND_TOKEN_LEN: usize = 24;

/// Flag that `server_main` sets to `true` once its TCP listener is bound.
pub type ServerStartedFlag = Arc<std::sync::atomic::AtomicBool>;
285
286pub async fn server_main(
287    port: u16,
288    token: Option<String>,
289    filename: String,
290    started: Option<ServerStartedFlag>,
291) -> Result<()> {
292    // if no token was provided, we generate one
293    let token = token.unwrap_or_else(|| {
294        // generate a random ASCII token
295        repeat_with(fastrand::alphanumeric)
296            .take(RAND_TOKEN_LEN)
297            .collect()
298    });
299
300    if token.len() < MIN_TOKEN_LEN {
301        bail!("Token `{token}` is too short. At least {MIN_TOKEN_LEN} characters are required!");
302    }
303
304    // load file
305    let start_read_header = web_time::Instant::now();
306    let header_result =
307        wellen::viewers::read_header_from_file(filename.clone(), &WELLEN_SURFER_DEFAULT_OPTIONS)
308            .map_err(|e| anyhow!("{e:?}"))
309            .with_context(|| format!("Failed to parse wave file: {filename}"))?;
310    info!(
311        "Loaded header of {filename} in {:?}",
312        start_read_header.elapsed()
313    );
314    let addr = SocketAddr::from(([127, 0, 0, 1], port));
315
316    // immutable read-only data
317    let url = format!("http://{addr:?}/{token}");
318    let url_copy = url.clone();
319    let token_copy = token.clone();
320    let shared = Arc::new(ReadOnly {
321        url,
322        token,
323        filename,
324        hierarchy: header_result.hierarchy,
325        file_format: header_result.file_format,
326        header_len: 0, // FIXME: get value from wellen
327        body_len: header_result.body_len,
328        body_progress: Arc::new(AtomicU64::new(0)),
329    });
330    // state can be written by the loading thread
331    let state = Arc::new(RwLock::new(State::default()));
332    // channel to communicate with loader
333    let (tx, rx) = std::sync::mpsc::channel::<SignalRequest>();
334    // start work thread
335    let shared_2 = shared.clone();
336    let state_2 = state.clone();
337    std::thread::spawn(move || loader(shared_2, header_result.body, state_2, rx));
338
339    // print out status
340    info!("Starting server on {addr:?}. To use:");
341    info!("1. Setup an ssh tunnel: -L {port}:localhost:{port}");
342    let hostname = whoami::fallible::hostname();
343    if let Ok(hostname) = hostname.as_ref() {
344        let username = whoami::username();
345        info!(
346            "   The correct command may be: ssh -L {port}:localhost:{port} {username}@{hostname} "
347        );
348    }
349
350    info!("2. Start Surfer: surfer {url_copy} ");
351    if let Ok(hostname) = hostname {
352        let hosturl = format!("http://{hostname}:{port}/{token_copy}");
353        info!("or, if the host is directly accessible:");
354        info!("1. Start Surfer: surfer {hosturl} ");
355    }
356    // create listener and serve it
357    let listener = TcpListener::bind(&addr).await?;
358
359    // we have started the server
360    if let Some(started) = started {
361        started.store(true, Ordering::SeqCst);
362    }
363
364    // main server loop
365    loop {
366        let (stream, _) = listener.accept().await?;
367        let io = TokioIo::new(stream);
368
369        let state = state.clone();
370        let shared = shared.clone();
371        let tx = tx.clone();
372        tokio::task::spawn(async move {
373            let service =
374                service_fn(move |req| handle(state.clone(), shared.clone(), tx.clone(), req));
375            if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
376                error!("server error: {e}");
377            }
378        });
379    }
380}
381
382/// Thread that loads the body and signals.
383fn loader<R: BufRead + Seek + Sync + Send + 'static>(
384    shared: Arc<ReadOnly>,
385    body_cont: viewers::ReadBodyContinuation<R>,
386    state: Arc<RwLock<State>>,
387    rx: std::sync::mpsc::Receiver<SignalRequest>,
388) -> Result<()> {
389    // load the body of the file
390    let start_load_body = web_time::Instant::now();
391    let body_result = viewers::read_body(
392        body_cont,
393        &shared.hierarchy,
394        Some(shared.body_progress.clone()),
395    )
396    .map_err(|e| anyhow!("{e:?}"))
397    .with_context(|| format!("Failed to parse body of wave file: {}", shared.filename))?;
398    info!("Loaded body in {:?}", start_load_body.elapsed());
399
400    // update state with body results
401    {
402        let mut state = state.write().unwrap();
403        state.timetable = body_result.time_table;
404    }
405    // source is private, only owned by us
406    let mut source = body_result.source;
407
408    // process requests for signals to be loaded
409    loop {
410        let ids = rx.recv()?;
411
412        // make sure that we do not load signals that have already been loaded
413        let mut filtered_ids = {
414            let state_lock = state.read().unwrap();
415            ids.iter()
416                .filter(|id| !state_lock.signals.contains_key(id))
417                .cloned()
418                .collect::<Vec<_>>()
419        };
420
421        // check if there is anything left to do
422        if filtered_ids.is_empty() {
423            continue;
424        }
425
426        // load signals without holding the lock
427        filtered_ids.sort();
428        filtered_ids.dedup();
429        let result = source.load_signals(&filtered_ids, &shared.hierarchy, true);
430
431        // store signals
432        {
433            let mut state = state.write().unwrap();
434            for (id, signal) in result {
435                state.signals.insert(id, signal);
436            }
437        }
438    }
439
440    // unreachable!("the user needs to terminate the server")
441}