1use bincode::Options;
3use color_eyre::eyre::{anyhow, bail, Context};
4use color_eyre::Result;
5use http_body_util::Full;
6use hyper::body::Bytes;
7use hyper::server::conn::http1;
8use hyper::service::service_fn;
9use hyper::{Request, Response, StatusCode};
10use hyper_util::rt::TokioIo;
11use log::{error, info, warn};
12use std::collections::HashMap;
13use std::io::{BufRead, Seek};
14use std::iter::repeat_with;
15use std::net::SocketAddr;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::mpsc::Sender;
18use std::sync::{Arc, RwLock};
19use tokio::net::TcpListener;
20use wellen::{
21 viewers, CompressedSignal, CompressedTimeTable, FileFormat, Hierarchy, Signal, SignalRef, Time,
22};
23
24use crate::{
25 Status, BINCODE_OPTIONS, HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER, SURFER_VERSION,
26 WELLEN_SURFER_DEFAULT_OPTIONS, WELLEN_VERSION, X_SURFER_VERSION, X_WELLEN_VERSION,
27};
28
29struct ReadOnly {
30 url: String,
31 token: String,
32 filename: String,
33 hierarchy: Hierarchy,
34 file_format: FileFormat,
35 header_len: u64,
36 body_len: u64,
37 body_progress: Arc<AtomicU64>,
38}
39
40#[derive(Default)]
41struct State {
42 timetable: Vec<Time>,
43 signals: HashMap<SignalRef, Signal>,
44}
45
46type SignalRequest = Vec<SignalRef>;
47
48fn get_info_page(shared: Arc<ReadOnly>) -> String {
49 let bytes_loaded = shared.body_progress.load(Ordering::SeqCst);
50
51 let progress = if bytes_loaded == shared.body_len {
52 format!(
53 "{} loaded",
54 bytesize::ByteSize::b(shared.body_len + shared.header_len)
55 )
56 } else {
57 format!(
58 "{} / {}",
59 bytesize::ByteSize::b(bytes_loaded + shared.header_len),
60 bytesize::ByteSize::b(shared.body_len + shared.header_len)
61 )
62 };
63
64 format!(
65 r#"
66 <!DOCTYPE html><html lang="en">
67 <head><title>Surver - Surfer Remote Server</title></head><body>
68 <h1>Surver - Surfer Remote Server</h1>
69 <b>To connect, run:</b> <code>surfer {}</code><br>
70 <b>Wellen version:</b> {WELLEN_VERSION}<br>
71 <b>Surfer version:</b> {SURFER_VERSION}<br>
72 <b>Filename:</b> {}<br>
73 <b>Progress:</b> {progress}<br>
74 </body></html>
75 "#,
76 shared.url, shared.filename
77 )
78}
79
80fn get_hierarchy(shared: Arc<ReadOnly>) -> Result<Vec<u8>> {
81 let mut raw = BINCODE_OPTIONS.serialize(&shared.file_format)?;
82 let mut raw2 = BINCODE_OPTIONS.serialize(&shared.hierarchy)?;
83 raw.append(&mut raw2);
84 let compressed = lz4_flex::compress_prepend_size(&raw);
85 info!(
86 "Sending hierarchy. {} raw, {} compressed.",
87 bytesize::ByteSize::b(raw.len() as u64),
88 bytesize::ByteSize::b(compressed.len() as u64)
89 );
90 Ok(compressed)
91}
92
93async fn get_timetable(state: Arc<RwLock<State>>) -> Result<Vec<u8>> {
94 #[allow(unused_assignments)]
96 let mut table = vec![];
97 loop {
98 {
99 let state = state.read().unwrap();
100 if !state.timetable.is_empty() {
101 table = state.timetable.clone();
102 break;
103 }
104 }
105 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
106 }
107 let raw_size = table.len() * std::mem::size_of::<Time>();
108 let compressed = BINCODE_OPTIONS.serialize(&CompressedTimeTable::compress(&table))?;
109 info!(
110 "Sending timetable. {} raw, {} compressed.",
111 bytesize::ByteSize::b(raw_size as u64),
112 bytesize::ByteSize::b(compressed.len() as u64)
113 );
114 Ok(compressed)
115}
116
117fn get_status(shared: Arc<ReadOnly>) -> Result<Vec<u8>> {
118 let status = Status {
119 bytes: shared.body_len + shared.header_len,
120 bytes_loaded: shared.body_progress.load(Ordering::SeqCst) + shared.header_len,
121 filename: shared.filename.clone(),
122 wellen_version: WELLEN_VERSION.to_string(),
123 surfer_version: SURFER_VERSION.to_string(),
124 file_format: shared.file_format,
125 };
126 Ok(serde_json::to_vec(&status)?)
127}
128
129async fn get_signals(
130 state: Arc<RwLock<State>>,
131 tx: Sender<SignalRequest>,
132 id_strings: &[&str],
133) -> Result<Vec<u8>> {
134 let mut ids = Vec::with_capacity(id_strings.len());
135 for id in id_strings.iter() {
136 ids.push(SignalRef::from_index(id.parse::<u64>()? as usize).unwrap());
137 }
138
139 if ids.is_empty() {
140 return Ok(vec![]);
141 }
142 let num_ids = ids.len();
143
144 tx.send(ids.clone())?;
146
147 let mut data = vec![];
149 leb128::write::unsigned(&mut data, num_ids as u64)?;
150 let mut raw_size = 0;
151 loop {
152 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
153 {
154 let state = state.read().unwrap();
155 if ids.iter().all(|id| state.signals.contains_key(id)) {
156 for id in ids {
157 let signal = &state.signals[&id];
158 raw_size += BINCODE_OPTIONS.serialize(signal)?.len();
159 let comp = CompressedSignal::compress(signal);
160 data.append(&mut BINCODE_OPTIONS.serialize(&comp)?);
161 }
162 break;
163 }
164 };
165 }
166 info!(
167 "Sending {} signals. {} raw, {} compressed.",
168 num_ids,
169 bytesize::ByteSize::b(raw_size as u64),
170 bytesize::ByteSize::b(data.len() as u64)
171 );
172 Ok(data)
173}
174
/// Name of the HTTP content-type header.
const CONTENT_TYPE: &str = "Content-Type";
/// MIME type attached to JSON responses (`get_status`).
const JSON_MIME: &str = "application/json";
177
/// Extension trait adding the headers every successful Surver response carries.
trait DefaultHeader {
    /// Returns `self` with the default headers added.
    fn default_header(self) -> Self;
}
181
182impl DefaultHeader for hyper::http::response::Builder {
183 fn default_header(self) -> Self {
184 self.header(HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER)
185 .header(X_WELLEN_VERSION, WELLEN_VERSION)
186 .header(X_SURFER_VERSION, SURFER_VERSION)
187 .header("Cache-Control", "no-cache")
188 }
189}
190
191async fn handle_cmd(
192 state: Arc<RwLock<State>>,
193 shared: Arc<ReadOnly>,
194 tx: Sender<SignalRequest>,
195 cmd: &str,
196 args: &[&str],
197) -> Result<Response<Full<Bytes>>> {
198 let response = match (cmd, args) {
199 ("get_status", []) => {
200 let body = get_status(shared)?;
201 Response::builder()
202 .status(StatusCode::OK)
203 .header(CONTENT_TYPE, JSON_MIME)
204 .default_header()
205 .body(Full::from(body))
206 }
207 ("get_hierarchy", []) => {
208 let body = get_hierarchy(shared)?;
209 Response::builder()
210 .status(StatusCode::OK)
211 .default_header()
212 .body(Full::from(body))
213 }
214 ("get_time_table", []) => {
215 let body = get_timetable(state).await?;
216 Response::builder()
217 .status(StatusCode::OK)
218 .default_header()
219 .body(Full::from(body))
220 }
221 ("get_signals", id_strings) => {
222 let body = get_signals(state, tx, id_strings).await?;
223 Response::builder()
224 .status(StatusCode::OK)
225 .default_header()
226 .body(Full::from(body))
227 }
228 _ => {
229 Response::builder()
231 .status(StatusCode::NOT_FOUND)
232 .body(Full::from(vec![]))
233 }
234 };
235 Ok(response?)
236}
237
238async fn handle(
239 state: Arc<RwLock<State>>,
240 shared: Arc<ReadOnly>,
241 tx: Sender<SignalRequest>,
242 req: Request<hyper::body::Incoming>,
243) -> Result<Response<Full<Bytes>>> {
244 let path_parts = req.uri().path().split('/').skip(1).collect::<Vec<_>>();
246
247 if let Some(provided_token) = path_parts.first() {
249 if *provided_token != shared.token {
250 warn!(
251 "Received request with invalid token: {provided_token} != {}\n{:?}",
252 shared.token,
253 req.uri()
254 );
255 return Ok(Response::builder()
256 .status(StatusCode::NOT_FOUND)
257 .body(Full::from(vec![]))?);
258 }
259 } else {
260 warn!("Received request with no token: {:?}", req.uri());
262 return Ok(Response::builder()
263 .status(StatusCode::NOT_FOUND)
264 .body(Full::from(vec![]))?);
265 }
266
267 let response = if let Some(cmd) = path_parts.get(1) {
269 handle_cmd(state, shared, tx, cmd, &path_parts[2..]).await?
270 } else {
271 let body = Full::from(get_info_page(shared));
273 Response::builder()
274 .status(StatusCode::OK)
275 .default_header()
276 .body(body)?
277 };
278
279 Ok(response)
280}
281
/// Shortest token `server_main` accepts.
const MIN_TOKEN_LEN: usize = 8;
/// Length of the alphanumeric token generated when the user supplies none.
const RAND_TOKEN_LEN: usize = 24;

/// Flag that `server_main` sets to `true` once its listener socket is bound.
pub type ServerStartedFlag = Arc<std::sync::atomic::AtomicBool>;
286
287pub async fn server_main(
288 port: u16,
289 token: Option<String>,
290 filename: String,
291 started: Option<ServerStartedFlag>,
292) -> Result<()> {
293 let token = token.unwrap_or_else(|| {
295 repeat_with(fastrand::alphanumeric)
297 .take(RAND_TOKEN_LEN)
298 .collect()
299 });
300
301 if token.len() < MIN_TOKEN_LEN {
302 bail!("Token `{token}` is too short. At least {MIN_TOKEN_LEN} characters are required!");
303 }
304
305 let start_read_header = web_time::Instant::now();
307 let header_result =
308 wellen::viewers::read_header_from_file(filename.clone(), &WELLEN_SURFER_DEFAULT_OPTIONS)
309 .map_err(|e| anyhow!("{e:?}"))
310 .with_context(|| format!("Failed to parse wave file: {filename}"))?;
311 info!(
312 "Loaded header of {filename} in {:?}",
313 start_read_header.elapsed()
314 );
315 let addr = SocketAddr::from(([127, 0, 0, 1], port));
316
317 let url = format!("http://{addr:?}/{token}");
319 let url_copy = url.clone();
320 let token_copy = token.clone();
321 let shared = Arc::new(ReadOnly {
322 url,
323 token,
324 filename,
325 hierarchy: header_result.hierarchy,
326 file_format: header_result.file_format,
327 header_len: 0, body_len: header_result.body_len,
329 body_progress: Arc::new(AtomicU64::new(0)),
330 });
331 let state = Arc::new(RwLock::new(State::default()));
333 let (tx, rx) = std::sync::mpsc::channel::<SignalRequest>();
335 let shared_2 = shared.clone();
337 let state_2 = state.clone();
338 std::thread::spawn(move || loader(shared_2, header_result.body, state_2, rx));
339
340 info!("Starting server on {addr:?}. To use:");
342 info!("1. Setup an ssh tunnel: -L {port}:localhost:{port}");
343 let hostname = whoami::fallible::hostname();
344 if let Ok(hostname) = hostname.as_ref() {
345 let username = whoami::username();
346 info!(
347 " The correct command may be: ssh -L {port}:localhost:{port} {username}@{hostname} "
348 );
349 }
350
351 info!("2. Start Surfer: surfer {url_copy} ");
352 if let Ok(hostname) = hostname {
353 let hosturl = format!("http://{hostname}:{port}/{token_copy}");
354 info!("or, if the host is directly accessible:");
355 info!("1. Start Surfer: surfer {hosturl} ");
356 }
357 let listener = TcpListener::bind(&addr).await?;
359
360 if let Some(started) = started {
362 started.store(true, Ordering::SeqCst);
363 }
364
365 loop {
367 let (stream, _) = listener.accept().await?;
368 let io = TokioIo::new(stream);
369
370 let state = state.clone();
371 let shared = shared.clone();
372 let tx = tx.clone();
373 tokio::task::spawn(async move {
374 let service =
375 service_fn(move |req| handle(state.clone(), shared.clone(), tx.clone(), req));
376 if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
377 error!("server error: {}", e);
378 }
379 });
380 }
381}
382
383fn loader<R: BufRead + Seek + Sync + Send + 'static>(
385 shared: Arc<ReadOnly>,
386 body_cont: viewers::ReadBodyContinuation<R>,
387 state: Arc<RwLock<State>>,
388 rx: std::sync::mpsc::Receiver<SignalRequest>,
389) -> Result<()> {
390 let start_load_body = web_time::Instant::now();
392 let body_result = viewers::read_body(
393 body_cont,
394 &shared.hierarchy,
395 Some(shared.body_progress.clone()),
396 )
397 .map_err(|e| anyhow!("{e:?}"))
398 .with_context(|| format!("Failed to parse body of wave file: {}", shared.filename))?;
399 info!("Loaded body in {:?}", start_load_body.elapsed());
400
401 {
403 let mut state = state.write().unwrap();
404 state.timetable = body_result.time_table;
405 }
406 let mut source = body_result.source;
408
409 loop {
411 let ids = rx.recv()?;
412
413 let mut filtered_ids = {
415 let state_lock = state.read().unwrap();
416 ids.iter()
417 .filter(|id| !state_lock.signals.contains_key(id))
418 .cloned()
419 .collect::<Vec<_>>()
420 };
421
422 if filtered_ids.is_empty() {
424 continue;
425 }
426
427 filtered_ids.sort();
429 filtered_ids.dedup();
430 let result = source.load_signals(&filtered_ids, &shared.hierarchy, true);
431
432 {
434 let mut state = state.write().unwrap();
435 for (id, signal) in result {
436 state.signals.insert(id, signal);
437 }
438 }
439 }
440
441 }