1use bincode::Options;
3use eyre::{anyhow, bail, Context, Result};
4use http_body_util::Full;
5use hyper::body::Bytes;
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response, StatusCode};
9use hyper_util::rt::TokioIo;
10use log::{error, info, warn};
11use std::collections::HashMap;
12use std::io::{BufRead, Seek};
13use std::iter::repeat_with;
14use std::net::SocketAddr;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::mpsc::Sender;
17use std::sync::{Arc, RwLock};
18use tokio::net::TcpListener;
19use wellen::{
20 viewers, CompressedSignal, CompressedTimeTable, FileFormat, Hierarchy, Signal, SignalRef, Time,
21};
22
23use crate::{
24 Status, BINCODE_OPTIONS, HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER, SURFER_VERSION,
25 WELLEN_SURFER_DEFAULT_OPTIONS, WELLEN_VERSION, X_SURFER_VERSION, X_WELLEN_VERSION,
26};
27
28struct ReadOnly {
29 url: String,
30 token: String,
31 filename: String,
32 hierarchy: Hierarchy,
33 file_format: FileFormat,
34 header_len: u64,
35 body_len: u64,
36 body_progress: Arc<AtomicU64>,
37}
38
39#[derive(Default)]
40struct State {
41 timetable: Vec<Time>,
42 signals: HashMap<SignalRef, Signal>,
43}
44
45type SignalRequest = Vec<SignalRef>;
46
47fn get_info_page(shared: Arc<ReadOnly>) -> String {
48 let bytes_loaded = shared.body_progress.load(Ordering::SeqCst);
49
50 let progress = if bytes_loaded == shared.body_len {
51 format!(
52 "{} loaded",
53 bytesize::ByteSize::b(shared.body_len + shared.header_len)
54 )
55 } else {
56 format!(
57 "{} / {}",
58 bytesize::ByteSize::b(bytes_loaded + shared.header_len),
59 bytesize::ByteSize::b(shared.body_len + shared.header_len)
60 )
61 };
62
63 format!(
64 r#"
65 <!DOCTYPE html><html lang="en">
66 <head><title>Surver - Surfer Remote Server</title></head><body>
67 <h1>Surver - Surfer Remote Server</h1>
68 <b>To connect, run:</b> <code>surfer {}</code><br>
69 <b>Wellen version:</b> {WELLEN_VERSION}<br>
70 <b>Surfer version:</b> {SURFER_VERSION}<br>
71 <b>Filename:</b> {}<br>
72 <b>Progress:</b> {progress}<br>
73 </body></html>
74 "#,
75 shared.url, shared.filename
76 )
77}
78
79fn get_hierarchy(shared: Arc<ReadOnly>) -> Result<Vec<u8>> {
80 let mut raw = BINCODE_OPTIONS.serialize(&shared.file_format)?;
81 let mut raw2 = BINCODE_OPTIONS.serialize(&shared.hierarchy)?;
82 raw.append(&mut raw2);
83 let compressed = lz4_flex::compress_prepend_size(&raw);
84 info!(
85 "Sending hierarchy. {} raw, {} compressed.",
86 bytesize::ByteSize::b(raw.len() as u64),
87 bytesize::ByteSize::b(compressed.len() as u64)
88 );
89 Ok(compressed)
90}
91
92async fn get_timetable(state: Arc<RwLock<State>>) -> Result<Vec<u8>> {
93 #[allow(unused_assignments)]
95 let mut table = vec![];
96 loop {
97 {
98 let state = state.read().unwrap();
99 if !state.timetable.is_empty() {
100 table = state.timetable.clone();
101 break;
102 }
103 }
104 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
105 }
106 let raw_size = table.len() * std::mem::size_of::<Time>();
107 let compressed = BINCODE_OPTIONS.serialize(&CompressedTimeTable::compress(&table))?;
108 info!(
109 "Sending timetable. {} raw, {} compressed.",
110 bytesize::ByteSize::b(raw_size as u64),
111 bytesize::ByteSize::b(compressed.len() as u64)
112 );
113 Ok(compressed)
114}
115
116fn get_status(shared: Arc<ReadOnly>) -> Result<Vec<u8>> {
117 let status = Status {
118 bytes: shared.body_len + shared.header_len,
119 bytes_loaded: shared.body_progress.load(Ordering::SeqCst) + shared.header_len,
120 filename: shared.filename.clone(),
121 wellen_version: WELLEN_VERSION.to_string(),
122 surfer_version: SURFER_VERSION.to_string(),
123 file_format: shared.file_format,
124 };
125 Ok(serde_json::to_vec(&status)?)
126}
127
128async fn get_signals(
129 state: Arc<RwLock<State>>,
130 tx: Sender<SignalRequest>,
131 id_strings: &[&str],
132) -> Result<Vec<u8>> {
133 let mut ids = Vec::with_capacity(id_strings.len());
134 for id in id_strings.iter() {
135 ids.push(SignalRef::from_index(id.parse::<u64>()? as usize).unwrap());
136 }
137
138 if ids.is_empty() {
139 return Ok(vec![]);
140 }
141 let num_ids = ids.len();
142
143 tx.send(ids.clone())?;
145
146 let mut data = vec![];
148 leb128::write::unsigned(&mut data, num_ids as u64)?;
149 let mut raw_size = 0;
150 loop {
151 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
152 {
153 let state = state.read().unwrap();
154 if ids.iter().all(|id| state.signals.contains_key(id)) {
155 for id in ids {
156 let signal = &state.signals[&id];
157 raw_size += BINCODE_OPTIONS.serialize(signal)?.len();
158 let comp = CompressedSignal::compress(signal);
159 data.append(&mut BINCODE_OPTIONS.serialize(&comp)?);
160 }
161 break;
162 }
163 };
164 }
165 info!(
166 "Sending {} signals. {} raw, {} compressed.",
167 num_ids,
168 bytesize::ByteSize::b(raw_size as u64),
169 bytesize::ByteSize::b(data.len() as u64)
170 );
171 Ok(data)
172}
173
/// Name of the HTTP header that carries a response's MIME type.
const CONTENT_TYPE: &str = "Content-Type";
/// MIME type attached to the JSON `get_status` response.
const JSON_MIME: &str = "application/json";
176
177trait DefaultHeader {
178 fn default_header(self) -> Self;
179}
180
181impl DefaultHeader for hyper::http::response::Builder {
182 fn default_header(self) -> Self {
183 self.header(HTTP_SERVER_KEY, HTTP_SERVER_VALUE_SURFER)
184 .header(X_WELLEN_VERSION, WELLEN_VERSION)
185 .header(X_SURFER_VERSION, SURFER_VERSION)
186 .header("Cache-Control", "no-cache")
187 }
188}
189
190async fn handle_cmd(
191 state: Arc<RwLock<State>>,
192 shared: Arc<ReadOnly>,
193 tx: Sender<SignalRequest>,
194 cmd: &str,
195 args: &[&str],
196) -> Result<Response<Full<Bytes>>> {
197 let response = match (cmd, args) {
198 ("get_status", []) => {
199 let body = get_status(shared)?;
200 Response::builder()
201 .status(StatusCode::OK)
202 .header(CONTENT_TYPE, JSON_MIME)
203 .default_header()
204 .body(Full::from(body))
205 }
206 ("get_hierarchy", []) => {
207 let body = get_hierarchy(shared)?;
208 Response::builder()
209 .status(StatusCode::OK)
210 .default_header()
211 .body(Full::from(body))
212 }
213 ("get_time_table", []) => {
214 let body = get_timetable(state).await?;
215 Response::builder()
216 .status(StatusCode::OK)
217 .default_header()
218 .body(Full::from(body))
219 }
220 ("get_signals", id_strings) => {
221 let body = get_signals(state, tx, id_strings).await?;
222 Response::builder()
223 .status(StatusCode::OK)
224 .default_header()
225 .body(Full::from(body))
226 }
227 _ => {
228 Response::builder()
230 .status(StatusCode::NOT_FOUND)
231 .body(Full::from(vec![]))
232 }
233 };
234 Ok(response?)
235}
236
237async fn handle(
238 state: Arc<RwLock<State>>,
239 shared: Arc<ReadOnly>,
240 tx: Sender<SignalRequest>,
241 req: Request<hyper::body::Incoming>,
242) -> Result<Response<Full<Bytes>>> {
243 let path_parts = req.uri().path().split('/').skip(1).collect::<Vec<_>>();
245
246 if let Some(provided_token) = path_parts.first() {
248 if *provided_token != shared.token {
249 warn!(
250 "Received request with invalid token: {provided_token} != {}\n{:?}",
251 shared.token,
252 req.uri()
253 );
254 return Ok(Response::builder()
255 .status(StatusCode::NOT_FOUND)
256 .body(Full::from(vec![]))?);
257 }
258 } else {
259 warn!("Received request with no token: {:?}", req.uri());
261 return Ok(Response::builder()
262 .status(StatusCode::NOT_FOUND)
263 .body(Full::from(vec![]))?);
264 }
265
266 let response = if let Some(cmd) = path_parts.get(1) {
268 handle_cmd(state, shared, tx, cmd, &path_parts[2..]).await?
269 } else {
270 let body = Full::from(get_info_page(shared));
272 Response::builder()
273 .status(StatusCode::OK)
274 .default_header()
275 .body(body)?
276 };
277
278 Ok(response)
279}
280
/// Minimum accepted token length, as measured by `String::len` (i.e. in bytes).
const MIN_TOKEN_LEN: usize = 8;
/// Length of the random alphanumeric token generated when none is supplied.
const RAND_TOKEN_LEN: usize = 24;

/// Flag that `server_main` sets to `true` once its TCP listener is bound.
pub type ServerStartedFlag = Arc<std::sync::atomic::AtomicBool>;
285
286pub async fn server_main(
287 port: u16,
288 token: Option<String>,
289 filename: String,
290 started: Option<ServerStartedFlag>,
291) -> Result<()> {
292 let token = token.unwrap_or_else(|| {
294 repeat_with(fastrand::alphanumeric)
296 .take(RAND_TOKEN_LEN)
297 .collect()
298 });
299
300 if token.len() < MIN_TOKEN_LEN {
301 bail!("Token `{token}` is too short. At least {MIN_TOKEN_LEN} characters are required!");
302 }
303
304 let start_read_header = web_time::Instant::now();
306 let header_result =
307 wellen::viewers::read_header_from_file(filename.clone(), &WELLEN_SURFER_DEFAULT_OPTIONS)
308 .map_err(|e| anyhow!("{e:?}"))
309 .with_context(|| format!("Failed to parse wave file: {filename}"))?;
310 info!(
311 "Loaded header of {filename} in {:?}",
312 start_read_header.elapsed()
313 );
314 let addr = SocketAddr::from(([127, 0, 0, 1], port));
315
316 let url = format!("http://{addr:?}/{token}");
318 let url_copy = url.clone();
319 let token_copy = token.clone();
320 let shared = Arc::new(ReadOnly {
321 url,
322 token,
323 filename,
324 hierarchy: header_result.hierarchy,
325 file_format: header_result.file_format,
326 header_len: 0, body_len: header_result.body_len,
328 body_progress: Arc::new(AtomicU64::new(0)),
329 });
330 let state = Arc::new(RwLock::new(State::default()));
332 let (tx, rx) = std::sync::mpsc::channel::<SignalRequest>();
334 let shared_2 = shared.clone();
336 let state_2 = state.clone();
337 std::thread::spawn(move || loader(shared_2, header_result.body, state_2, rx));
338
339 info!("Starting server on {addr:?}. To use:");
341 info!("1. Setup an ssh tunnel: -L {port}:localhost:{port}");
342 let hostname = whoami::fallible::hostname();
343 if let Ok(hostname) = hostname.as_ref() {
344 let username = whoami::username();
345 info!(
346 " The correct command may be: ssh -L {port}:localhost:{port} {username}@{hostname} "
347 );
348 }
349
350 info!("2. Start Surfer: surfer {url_copy} ");
351 if let Ok(hostname) = hostname {
352 let hosturl = format!("http://{hostname}:{port}/{token_copy}");
353 info!("or, if the host is directly accessible:");
354 info!("1. Start Surfer: surfer {hosturl} ");
355 }
356 let listener = TcpListener::bind(&addr).await?;
358
359 if let Some(started) = started {
361 started.store(true, Ordering::SeqCst);
362 }
363
364 loop {
366 let (stream, _) = listener.accept().await?;
367 let io = TokioIo::new(stream);
368
369 let state = state.clone();
370 let shared = shared.clone();
371 let tx = tx.clone();
372 tokio::task::spawn(async move {
373 let service =
374 service_fn(move |req| handle(state.clone(), shared.clone(), tx.clone(), req));
375 if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
376 error!("server error: {e}");
377 }
378 });
379 }
380}
381
382fn loader<R: BufRead + Seek + Sync + Send + 'static>(
384 shared: Arc<ReadOnly>,
385 body_cont: viewers::ReadBodyContinuation<R>,
386 state: Arc<RwLock<State>>,
387 rx: std::sync::mpsc::Receiver<SignalRequest>,
388) -> Result<()> {
389 let start_load_body = web_time::Instant::now();
391 let body_result = viewers::read_body(
392 body_cont,
393 &shared.hierarchy,
394 Some(shared.body_progress.clone()),
395 )
396 .map_err(|e| anyhow!("{e:?}"))
397 .with_context(|| format!("Failed to parse body of wave file: {}", shared.filename))?;
398 info!("Loaded body in {:?}", start_load_body.elapsed());
399
400 {
402 let mut state = state.write().unwrap();
403 state.timetable = body_result.time_table;
404 }
405 let mut source = body_result.source;
407
408 loop {
410 let ids = rx.recv()?;
411
412 let mut filtered_ids = {
414 let state_lock = state.read().unwrap();
415 ids.iter()
416 .filter(|id| !state_lock.signals.contains_key(id))
417 .cloned()
418 .collect::<Vec<_>>()
419 };
420
421 if filtered_ids.is_empty() {
423 continue;
424 }
425
426 filtered_ids.sort();
428 filtered_ids.dedup();
429 let result = source.load_signals(&filtered_ids, &shared.hierarchy, true);
430
431 {
433 let mut state = state.write().unwrap();
434 for (id, signal) in result {
435 state.signals.insert(id, signal);
436 }
437 }
438 }
439
440 }