spade_typeinference/
replacement.rs

1use std::{cell::RefCell, collections::BTreeMap};
2
3use serde::{Deserialize, Serialize};
4
5use crate::equation::TypeVarID;
6
7#[derive(Clone, Serialize, Deserialize)]
8pub struct Replacements {
9    replacements: RefCell<BTreeMap<TypeVarID, TypeVarID>>,
10}
11
12impl Replacements {
13    fn new() -> Self {
14        Replacements {
15            replacements: RefCell::new(BTreeMap::new()),
16        }
17    }
18}
19
20#[derive(Clone, Serialize, Deserialize)]
21pub struct ReplacementStack {
22    inner: Vec<Replacements>,
23
24    lookup_steps: RefCell<BTreeMap<usize, usize>>,
25}
26
27impl ReplacementStack {
28    pub fn new() -> Self {
29        Self {
30            inner: vec![Replacements::new()],
31            lookup_steps: RefCell::new(BTreeMap::new()),
32        }
33    }
34
35    pub fn push(&mut self) {
36        self.inner.push(Replacements::new());
37    }
38
39    pub fn discard_top(&mut self) {
40        self.inner.pop();
41    }
42
43    pub fn insert(&mut self, from: TypeVarID, to: TypeVarID) {
44        self.inner
45            .last_mut()
46            .expect("there was no map in the replacement stack")
47            .replacements
48            .borrow_mut()
49            .insert(from, to);
50    }
51
52    pub fn get(&self, mut key: TypeVarID) -> TypeVarID {
53        let top = self
54            .inner
55            .last()
56            .expect("Did not have an entry in the replacement stack");
57
58        // store all nodes in the chain we're walking on
59        let mut seen = Vec::new();
60        let mut replacements = top.replacements.borrow_mut();
61        while let Some(target) = replacements.get(&key) {
62            seen.push(key);
63            key = *target;
64        }
65        let target = key;
66        // update all of them to the end of the chain
67        for key in seen {
68            replacements.insert(key, target);
69        }
70        target
71    }
72
73    pub fn all(&self) -> Vec<&RefCell<BTreeMap<TypeVarID, TypeVarID>>> {
74        self.inner.iter().map(|var| &var.replacements).collect()
75    }
76}