spade_typeinference/
constraints.rs

1use num::BigInt;
2use serde::{Deserialize, Serialize};
3use spade_common::location_info::{Loc, WithLocation};
4use spade_types::KnownType;
5
6use crate::{
7    equation::{TypeVar, TypeVarID},
8    TypeState,
9};
10
11#[derive(Debug, Clone)]
12pub enum ConstraintExpr {
13    Bool(bool),
14    Integer(BigInt),
15    Var(TypeVarID),
16    Sum(Box<ConstraintExpr>, Box<ConstraintExpr>),
17    Difference(Box<ConstraintExpr>, Box<ConstraintExpr>),
18    Product(Box<ConstraintExpr>, Box<ConstraintExpr>),
19    Div(Box<ConstraintExpr>, Box<ConstraintExpr>),
20    Mod(Box<ConstraintExpr>, Box<ConstraintExpr>),
21    Sub(Box<ConstraintExpr>),
22    Eq(Box<ConstraintExpr>, Box<ConstraintExpr>),
23    NotEq(Box<ConstraintExpr>, Box<ConstraintExpr>),
24    /// The number of bits required to represent the specified number. In practice
25    /// inner.log2().floor()+1
26    UintBitsToRepresent(Box<ConstraintExpr>),
27}
28impl WithLocation for ConstraintExpr {}
29
30impl ConstraintExpr {
31    pub fn debug_display(&self, type_state: &TypeState) -> String {
32        match self {
33            ConstraintExpr::Bool(b) => format!("{b}"),
34            ConstraintExpr::Integer(v) => format!("{v}"),
35            ConstraintExpr::Var(type_var_id) => {
36                format!("{}", type_var_id.debug_resolve(type_state))
37            }
38            ConstraintExpr::Sum(lhs, rhs) => {
39                format!(
40                    "({} + {})",
41                    lhs.debug_display(type_state),
42                    rhs.debug_display(type_state)
43                )
44            }
45            ConstraintExpr::Difference(lhs, rhs) => {
46                format!(
47                    "({} - {})",
48                    lhs.debug_display(type_state),
49                    rhs.debug_display(type_state)
50                )
51            }
52            ConstraintExpr::Product(lhs, rhs) => {
53                format!(
54                    "({} * {})",
55                    lhs.debug_display(type_state),
56                    rhs.debug_display(type_state)
57                )
58            }
59            ConstraintExpr::Div(lhs, rhs) => {
60                format!(
61                    "({} / {})",
62                    lhs.debug_display(type_state),
63                    rhs.debug_display(type_state)
64                )
65            }
66            ConstraintExpr::Mod(lhs, rhs) => {
67                format!(
68                    "({} % {})",
69                    lhs.debug_display(type_state),
70                    rhs.debug_display(type_state)
71                )
72            }
73            ConstraintExpr::Sub(lhs) => {
74                format!("(-{})", lhs.debug_display(type_state))
75            }
76            ConstraintExpr::Eq(lhs, rhs) => {
77                format!(
78                    "({} == {})",
79                    lhs.debug_display(type_state),
80                    rhs.debug_display(type_state)
81                )
82            }
83            ConstraintExpr::NotEq(lhs, rhs) => {
84                format!(
85                    "({} != {})",
86                    lhs.debug_display(type_state),
87                    rhs.debug_display(type_state)
88                )
89            }
90            ConstraintExpr::UintBitsToRepresent(c) => {
91                format!("uint_bits_to_fit({})", c.debug_display(type_state))
92            }
93        }
94    }
95}
96
97impl ConstraintExpr {
98    /// Evaluates the ConstraintExpr returning a new simplified form
99    fn evaluate(&self, type_state: &TypeState) -> ConstraintExpr {
100        let binop =
101            |lhs: &ConstraintExpr, rhs: &ConstraintExpr, op: &dyn Fn(BigInt, BigInt) -> BigInt| {
102                match (lhs.evaluate(type_state), rhs.evaluate(type_state)) {
103                    (ConstraintExpr::Integer(l), ConstraintExpr::Integer(r)) => {
104                        ConstraintExpr::Integer(op(l, r))
105                    }
106                    _ => self.clone(),
107                }
108            };
109        match self {
110            ConstraintExpr::Integer(_) => self.clone(),
111            ConstraintExpr::Bool(_) => self.clone(),
112            ConstraintExpr::Var(v) => match v.resolve(type_state) {
113                TypeVar::Known(_, known_type, _) => match known_type {
114                    KnownType::Integer(i) => ConstraintExpr::Integer(i.clone()),
115                    KnownType::Bool(b) => ConstraintExpr::Bool(b.clone()),
116                    KnownType::Error => self.clone(),
117                    KnownType::Named(_)
118                    | KnownType::Tuple
119                    | KnownType::Array
120                    | KnownType::Wire
121                    | KnownType::Inverted => {
122                        panic!("Inferred non-integer or bool for constraint variable")
123                    }
124                },
125                TypeVar::Unknown(_, _, _, _) => self.clone(),
126            },
127            ConstraintExpr::Sum(lhs, rhs) => binop(lhs, rhs, &|l, r| l + r),
128            ConstraintExpr::Difference(lhs, rhs) => binop(lhs, rhs, &|l, r| l - r),
129            ConstraintExpr::Product(lhs, rhs) => binop(lhs, rhs, &|l, r| l * r),
130            ConstraintExpr::Div(lhs, rhs) => binop(lhs, rhs, &|l, r| l / r),
131            ConstraintExpr::Mod(lhs, rhs) => binop(lhs, rhs, &|l, r| l % r),
132            ConstraintExpr::Sub(inner) => match inner.evaluate(type_state) {
133                ConstraintExpr::Integer(val) => ConstraintExpr::Integer(-val),
134                _ => self.clone(),
135            },
136            ConstraintExpr::Eq(lhs, rhs) => {
137                match (lhs.evaluate(type_state), rhs.evaluate(type_state)) {
138                    (ConstraintExpr::Integer(l), ConstraintExpr::Integer(r)) => {
139                        ConstraintExpr::Bool(l == r)
140                    }
141                    _ => self.clone(),
142                }
143            }
144            ConstraintExpr::NotEq(lhs, rhs) => {
145                match (lhs.evaluate(type_state), rhs.evaluate(type_state)) {
146                    (ConstraintExpr::Integer(l), ConstraintExpr::Integer(r)) => {
147                        ConstraintExpr::Bool(l != r)
148                    }
149                    _ => self.clone(),
150                }
151            }
152            ConstraintExpr::UintBitsToRepresent(inner) => match inner.evaluate(type_state) {
153                ConstraintExpr::Integer(val) => ConstraintExpr::Integer(val.bits().into()),
154                _ => self.clone(),
155            },
156        }
157    }
158
159    pub fn with_context(
160        self,
161        replaces: &TypeVarID,
162        inside: &TypeVarID,
163        source: ConstraintSource,
164    ) -> ConstraintRhs {
165        ConstraintRhs {
166            constraint: self,
167            context: ConstraintContext {
168                replaces: replaces.clone(),
169                inside: inside.clone(),
170                source,
171            },
172        }
173    }
174}
175
176impl std::ops::Add for ConstraintExpr {
177    type Output = ConstraintExpr;
178
179    fn add(self, rhs: Self) -> Self::Output {
180        ConstraintExpr::Sum(Box::new(self), Box::new(rhs))
181    }
182}
183
184impl std::ops::Sub for ConstraintExpr {
185    type Output = ConstraintExpr;
186
187    fn sub(self, rhs: Self) -> Self::Output {
188        ConstraintExpr::Sum(Box::new(self), Box::new(-rhs))
189    }
190}
191
192impl std::ops::Neg for ConstraintExpr {
193    type Output = ConstraintExpr;
194
195    fn neg(self) -> Self::Output {
196        ConstraintExpr::Sub(Box::new(self))
197    }
198}
199
200pub fn bits_to_store(inner: ConstraintExpr) -> ConstraintExpr {
201    ConstraintExpr::UintBitsToRepresent(Box::new(inner))
202}
203
204// Shorthand constructors for constraint_expr
205pub fn ce_var(v: &TypeVarID) -> ConstraintExpr {
206    ConstraintExpr::Var(v.clone())
207}
208pub fn ce_int(v: BigInt) -> ConstraintExpr {
209    ConstraintExpr::Integer(v)
210}
211
212#[derive(Debug, Clone, PartialEq)]
213pub enum ConstraintSource {
214    AdditionOutput,
215    MultOutput,
216    ArrayIndexing,
217    MemoryIndexing,
218    Concatenation,
219    PipelineRegOffset { reg: Loc<()>, total: Loc<()> },
220    PipelineRegCount { reg: Loc<()>, total: Loc<()> },
221    PipelineAvailDepth,
222    RangeIndex,
223    RangeIndexOutputSize,
224    ArraySize,
225    TypeLevelIf,
226    Where,
227}
228
229impl std::fmt::Display for ConstraintSource {
230    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        match self {
232            ConstraintSource::AdditionOutput => write!(f, "AdditionOutput"),
233            ConstraintSource::MultOutput => write!(f, "MultiplicationOutput"),
234            ConstraintSource::ArrayIndexing => write!(f, "ArrayIndexing"),
235            ConstraintSource::MemoryIndexing => write!(f, "MemoryIndexing"),
236            ConstraintSource::Concatenation => write!(f, "Concatenation"),
237            ConstraintSource::Where => write!(f, "Where"),
238            ConstraintSource::RangeIndex => write!(f, "RangeIndex"),
239            ConstraintSource::RangeIndexOutputSize => write!(f, "RangeIndexOutputSize"),
240            ConstraintSource::ArraySize => write!(f, "ArraySize"),
241            ConstraintSource::PipelineRegOffset { .. } => write!(f, "PipelineRegOffset"),
242            ConstraintSource::PipelineRegCount { .. } => write!(f, "PipelineRegOffset"),
243            ConstraintSource::PipelineAvailDepth => write!(f, "PipelineAvailDepth"),
244            ConstraintSource::TypeLevelIf => write!(f, "TypeLevelIf"),
245        }
246    }
247}
248
249#[derive(Debug, Clone)]
250pub struct ConstraintRhs {
251    /// The actual constraint
252    pub constraint: ConstraintExpr,
253    pub context: ConstraintContext,
254}
255
256impl ConstraintRhs {
257    pub fn debug_display(&self, type_state: &TypeState) -> String {
258        self.constraint.debug_display(type_state)
259    }
260}
261
262impl WithLocation for ConstraintRhs {}
263
264#[derive(Clone, Serialize, Deserialize)]
265pub struct TypeConstraints {
266    #[serde(skip)]
267    pub inner: Vec<(TypeVarID, Loc<ConstraintRhs>)>,
268}
269
270impl TypeConstraints {
271    pub fn new() -> Self {
272        Self { inner: vec![] }
273    }
274
275    pub fn add_int_constraint(&mut self, lhs: TypeVarID, rhs: Loc<ConstraintRhs>) {
276        self.inner.push((lhs, rhs));
277    }
278
279    /// Calls `evaluate` on all constraints. If any constraints are now `T = Integer(val)`,
280    /// those updated values are returned. Such constraints are then removed
281    pub fn update_type_level_value_constraints(
282        self,
283        type_state: &TypeState,
284    ) -> (
285        TypeConstraints,
286        Vec<Loc<(TypeVarID, ConstraintReplacement)>>,
287    ) {
288        let mut new_known = vec![];
289        let remaining = self
290            .inner
291            .into_iter()
292            .filter_map(|(expr, rhs)| {
293                let mut rhs = rhs.clone();
294                rhs.constraint = rhs.constraint.evaluate(type_state);
295
296                match &rhs.constraint {
297                    ConstraintExpr::Integer(val) => {
298                        // ().at_loc(..).map is a somewhat ugly way to wrap an arbitrary type
299                        // in a known Loc. This is done to avoid having to impl WithLocation for
300                        // the unusual tuple used here
301                        let replacement = ConstraintReplacement {
302                            val: KnownType::Integer(val.clone()),
303                            context: rhs.context.clone(),
304                        };
305                        new_known
306                            .push(().at_loc(&rhs).map(|_| (expr.clone(), replacement.clone())));
307
308                        None
309                    }
310                    // NOTE: If we add more branches that look like this, combine it with
311                    // Integer
312                    ConstraintExpr::Bool(val) => {
313                        let replacement = ConstraintReplacement {
314                            val: KnownType::Bool(val.clone()),
315                            context: rhs.context.clone(),
316                        };
317                        new_known
318                            .push(().at_loc(&rhs).map(|_| (expr.clone(), replacement.clone())));
319
320                        None
321                    }
322                    ConstraintExpr::Var(_)
323                    | ConstraintExpr::Sum(_, _)
324                    | ConstraintExpr::Div(_, _)
325                    | ConstraintExpr::Mod(_, _)
326                    | ConstraintExpr::Eq(_, _)
327                    | ConstraintExpr::NotEq(_, _)
328                    | ConstraintExpr::Difference(_, _)
329                    | ConstraintExpr::Product(_, _)
330                    | ConstraintExpr::UintBitsToRepresent(_)
331                    | ConstraintExpr::Sub(_) => Some((expr.clone(), rhs)),
332                }
333            })
334            .collect();
335
336        (TypeConstraints { inner: remaining }, new_known)
337    }
338}
339
340#[derive(Clone, Debug)]
341pub struct ConstraintReplacement {
342    /// The actual constraint
343    pub val: KnownType,
344    pub context: ConstraintContext,
345}
346
347#[derive(Clone, Debug)]
348pub struct ConstraintContext {
349    /// A type var in which this constraint applies. For example, if a constraint
350    /// this constraint constrains `t1` inside `int<t1>`, then `from` is `int<t1>`
351    pub inside: TypeVarID,
352    /// The left hand side which this constrains. Used together with `from` to construct
353    /// type errors
354    pub replaces: TypeVarID,
355    /// Context in which this constraint was added to give hints to the user
356    pub source: ConstraintSource,
357}