1use num::BigInt;
2use serde::{Deserialize, Serialize};
3use spade_common::location_info::{Loc, WithLocation};
4use spade_types::KnownType;
5
6use crate::{
7 equation::{TypeVar, TypeVarID},
8 TypeState,
9};
10
/// An arithmetic or boolean expression over type-level values, used as the
/// expression part of a `ConstraintRhs`.
///
/// Expressions are folded into a literal (`Integer` or `Bool`) by
/// `ConstraintExpr::evaluate` once the type variables they reference are known
/// in the `TypeState`.
#[derive(Debug, Clone)]
pub enum ConstraintExpr {
    /// A literal type-level boolean.
    Bool(bool),
    /// A literal type-level integer.
    Integer(BigInt),
    /// A reference to a type variable whose value may not be known yet.
    Var(TypeVarID),
    /// `lhs + rhs`
    Sum(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `lhs - rhs`
    Difference(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `lhs * rhs`
    Product(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `lhs / rhs`, folded with `BigInt`'s `/` operator.
    Div(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `lhs % rhs`, folded with `BigInt`'s `%` operator.
    Mod(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// Unary negation `-inner`. Despite the name this is not subtraction;
    /// binary subtraction is `Difference`.
    Sub(Box<ConstraintExpr>),
    /// `lhs == rhs`; folds to a `Bool`.
    Eq(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `lhs != rhs`; folds to a `Bool`.
    NotEq(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// Number of bits needed to represent `inner` as an unsigned integer,
    /// computed with `BigInt::bits`.
    UintBitsToRepresent(Box<ConstraintExpr>),
}
impl WithLocation for ConstraintExpr {}
29
30impl ConstraintExpr {
31 pub fn debug_display(&self, type_state: &TypeState) -> String {
32 match self {
33 ConstraintExpr::Bool(b) => format!("{b}"),
34 ConstraintExpr::Integer(v) => format!("{v}"),
35 ConstraintExpr::Var(type_var_id) => {
36 format!("{}", type_var_id.debug_resolve(type_state))
37 }
38 ConstraintExpr::Sum(lhs, rhs) => {
39 format!(
40 "({} + {})",
41 lhs.debug_display(type_state),
42 rhs.debug_display(type_state)
43 )
44 }
45 ConstraintExpr::Difference(lhs, rhs) => {
46 format!(
47 "({} - {})",
48 lhs.debug_display(type_state),
49 rhs.debug_display(type_state)
50 )
51 }
52 ConstraintExpr::Product(lhs, rhs) => {
53 format!(
54 "({} * {})",
55 lhs.debug_display(type_state),
56 rhs.debug_display(type_state)
57 )
58 }
59 ConstraintExpr::Div(lhs, rhs) => {
60 format!(
61 "({} / {})",
62 lhs.debug_display(type_state),
63 rhs.debug_display(type_state)
64 )
65 }
66 ConstraintExpr::Mod(lhs, rhs) => {
67 format!(
68 "({} % {})",
69 lhs.debug_display(type_state),
70 rhs.debug_display(type_state)
71 )
72 }
73 ConstraintExpr::Sub(lhs) => {
74 format!("(-{})", lhs.debug_display(type_state))
75 }
76 ConstraintExpr::Eq(lhs, rhs) => {
77 format!(
78 "({} == {})",
79 lhs.debug_display(type_state),
80 rhs.debug_display(type_state)
81 )
82 }
83 ConstraintExpr::NotEq(lhs, rhs) => {
84 format!(
85 "({} != {})",
86 lhs.debug_display(type_state),
87 rhs.debug_display(type_state)
88 )
89 }
90 ConstraintExpr::UintBitsToRepresent(c) => {
91 format!("uint_bits_to_fit({})", c.debug_display(type_state))
92 }
93 }
94 }
95}
96
97impl ConstraintExpr {
98 fn evaluate(&self, type_state: &TypeState) -> ConstraintExpr {
100 let binop =
101 |lhs: &ConstraintExpr, rhs: &ConstraintExpr, op: &dyn Fn(BigInt, BigInt) -> BigInt| {
102 match (lhs.evaluate(type_state), rhs.evaluate(type_state)) {
103 (ConstraintExpr::Integer(l), ConstraintExpr::Integer(r)) => {
104 ConstraintExpr::Integer(op(l, r))
105 }
106 _ => self.clone(),
107 }
108 };
109 match self {
110 ConstraintExpr::Integer(_) => self.clone(),
111 ConstraintExpr::Bool(_) => self.clone(),
112 ConstraintExpr::Var(v) => match v.resolve(type_state) {
113 TypeVar::Known(_, known_type, _) => match known_type {
114 KnownType::Integer(i) => ConstraintExpr::Integer(i.clone()),
115 KnownType::Bool(b) => ConstraintExpr::Bool(b.clone()),
116 KnownType::Error => self.clone(),
117 KnownType::Named(_)
118 | KnownType::Tuple
119 | KnownType::Array
120 | KnownType::Wire
121 | KnownType::Inverted => {
122 panic!("Inferred non-integer or bool for constraint variable")
123 }
124 },
125 TypeVar::Unknown(_, _, _, _) => self.clone(),
126 },
127 ConstraintExpr::Sum(lhs, rhs) => binop(lhs, rhs, &|l, r| l + r),
128 ConstraintExpr::Difference(lhs, rhs) => binop(lhs, rhs, &|l, r| l - r),
129 ConstraintExpr::Product(lhs, rhs) => binop(lhs, rhs, &|l, r| l * r),
130 ConstraintExpr::Div(lhs, rhs) => binop(lhs, rhs, &|l, r| l / r),
131 ConstraintExpr::Mod(lhs, rhs) => binop(lhs, rhs, &|l, r| l % r),
132 ConstraintExpr::Sub(inner) => match inner.evaluate(type_state) {
133 ConstraintExpr::Integer(val) => ConstraintExpr::Integer(-val),
134 _ => self.clone(),
135 },
136 ConstraintExpr::Eq(lhs, rhs) => {
137 match (lhs.evaluate(type_state), rhs.evaluate(type_state)) {
138 (ConstraintExpr::Integer(l), ConstraintExpr::Integer(r)) => {
139 ConstraintExpr::Bool(l == r)
140 }
141 _ => self.clone(),
142 }
143 }
144 ConstraintExpr::NotEq(lhs, rhs) => {
145 match (lhs.evaluate(type_state), rhs.evaluate(type_state)) {
146 (ConstraintExpr::Integer(l), ConstraintExpr::Integer(r)) => {
147 ConstraintExpr::Bool(l != r)
148 }
149 _ => self.clone(),
150 }
151 }
152 ConstraintExpr::UintBitsToRepresent(inner) => match inner.evaluate(type_state) {
153 ConstraintExpr::Integer(val) => ConstraintExpr::Integer(val.bits().into()),
154 _ => self.clone(),
155 },
156 }
157 }
158
159 pub fn with_context(
160 self,
161 replaces: &TypeVarID,
162 inside: &TypeVarID,
163 source: ConstraintSource,
164 ) -> ConstraintRhs {
165 ConstraintRhs {
166 constraint: self,
167 context: ConstraintContext {
168 replaces: replaces.clone(),
169 inside: inside.clone(),
170 source,
171 },
172 }
173 }
174}
175
176impl std::ops::Add for ConstraintExpr {
177 type Output = ConstraintExpr;
178
179 fn add(self, rhs: Self) -> Self::Output {
180 ConstraintExpr::Sum(Box::new(self), Box::new(rhs))
181 }
182}
183
184impl std::ops::Sub for ConstraintExpr {
185 type Output = ConstraintExpr;
186
187 fn sub(self, rhs: Self) -> Self::Output {
188 ConstraintExpr::Sum(Box::new(self), Box::new(-rhs))
189 }
190}
191
192impl std::ops::Neg for ConstraintExpr {
193 type Output = ConstraintExpr;
194
195 fn neg(self) -> Self::Output {
196 ConstraintExpr::Sub(Box::new(self))
197 }
198}
199
200pub fn bits_to_store(inner: ConstraintExpr) -> ConstraintExpr {
201 ConstraintExpr::UintBitsToRepresent(Box::new(inner))
202}
203
204pub fn ce_var(v: &TypeVarID) -> ConstraintExpr {
206 ConstraintExpr::Var(v.clone())
207}
208pub fn ce_int(v: BigInt) -> ConstraintExpr {
209 ConstraintExpr::Integer(v)
210}
211
/// Identifies which language construct produced a type-level constraint.
///
/// Stored in `ConstraintContext::source` and rendered by its `Display` impl.
#[derive(Debug, Clone, PartialEq)]
pub enum ConstraintSource {
    AdditionOutput,
    MultOutput,
    ArrayIndexing,
    MemoryIndexing,
    Concatenation,
    // NOTE(review): `reg` / `total` presumably locate the register and the
    // pipeline's total depth in the source; confirm against the constructor.
    PipelineRegOffset { reg: Loc<()>, total: Loc<()> },
    // NOTE(review): same field meaning as `PipelineRegOffset`, presumably.
    PipelineRegCount { reg: Loc<()>, total: Loc<()> },
    PipelineAvailDepth,
    RangeIndex,
    RangeIndexOutputSize,
    ArraySize,
    TypeLevelIf,
    Where,
}
228
229impl std::fmt::Display for ConstraintSource {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 match self {
232 ConstraintSource::AdditionOutput => write!(f, "AdditionOutput"),
233 ConstraintSource::MultOutput => write!(f, "MultiplicationOutput"),
234 ConstraintSource::ArrayIndexing => write!(f, "ArrayIndexing"),
235 ConstraintSource::MemoryIndexing => write!(f, "MemoryIndexing"),
236 ConstraintSource::Concatenation => write!(f, "Concatenation"),
237 ConstraintSource::Where => write!(f, "Where"),
238 ConstraintSource::RangeIndex => write!(f, "RangeIndex"),
239 ConstraintSource::RangeIndexOutputSize => write!(f, "RangeIndexOutputSize"),
240 ConstraintSource::ArraySize => write!(f, "ArraySize"),
241 ConstraintSource::PipelineRegOffset { .. } => write!(f, "PipelineRegOffset"),
242 ConstraintSource::PipelineRegCount { .. } => write!(f, "PipelineRegOffset"),
243 ConstraintSource::PipelineAvailDepth => write!(f, "PipelineAvailDepth"),
244 ConstraintSource::TypeLevelIf => write!(f, "TypeLevelIf"),
245 }
246 }
247}
248
249#[derive(Debug, Clone)]
250pub struct ConstraintRhs {
251 pub constraint: ConstraintExpr,
253 pub context: ConstraintContext,
254}
255
256impl ConstraintRhs {
257 pub fn debug_display(&self, type_state: &TypeState) -> String {
258 self.constraint.debug_display(type_state)
259 }
260}
261
262impl WithLocation for ConstraintRhs {}
263
/// The set of pending type-level constraints.
///
/// Each entry pairs a type variable with the located right-hand side it is
/// constrained by. `update_type_level_value_constraints` removes the entries
/// whose expression can be evaluated to a literal.
#[derive(Clone, Serialize, Deserialize)]
pub struct TypeConstraints {
    /// Pending `(constrained variable, right-hand side)` pairs.
    // Skipped by serde: serialization omits it, and a deserialized value
    // starts with an empty list.
    #[serde(skip)]
    pub inner: Vec<(TypeVarID, Loc<ConstraintRhs>)>,
}
269
270impl TypeConstraints {
271 pub fn new() -> Self {
272 Self { inner: vec![] }
273 }
274
275 pub fn add_int_constraint(&mut self, lhs: TypeVarID, rhs: Loc<ConstraintRhs>) {
276 self.inner.push((lhs, rhs));
277 }
278
279 pub fn update_type_level_value_constraints(
282 self,
283 type_state: &TypeState,
284 ) -> (
285 TypeConstraints,
286 Vec<Loc<(TypeVarID, ConstraintReplacement)>>,
287 ) {
288 let mut new_known = vec![];
289 let remaining = self
290 .inner
291 .into_iter()
292 .filter_map(|(expr, rhs)| {
293 let mut rhs = rhs.clone();
294 rhs.constraint = rhs.constraint.evaluate(type_state);
295
296 match &rhs.constraint {
297 ConstraintExpr::Integer(val) => {
298 let replacement = ConstraintReplacement {
302 val: KnownType::Integer(val.clone()),
303 context: rhs.context.clone(),
304 };
305 new_known
306 .push(().at_loc(&rhs).map(|_| (expr.clone(), replacement.clone())));
307
308 None
309 }
310 ConstraintExpr::Bool(val) => {
313 let replacement = ConstraintReplacement {
314 val: KnownType::Bool(val.clone()),
315 context: rhs.context.clone(),
316 };
317 new_known
318 .push(().at_loc(&rhs).map(|_| (expr.clone(), replacement.clone())));
319
320 None
321 }
322 ConstraintExpr::Var(_)
323 | ConstraintExpr::Sum(_, _)
324 | ConstraintExpr::Div(_, _)
325 | ConstraintExpr::Mod(_, _)
326 | ConstraintExpr::Eq(_, _)
327 | ConstraintExpr::NotEq(_, _)
328 | ConstraintExpr::Difference(_, _)
329 | ConstraintExpr::Product(_, _)
330 | ConstraintExpr::UintBitsToRepresent(_)
331 | ConstraintExpr::Sub(_) => Some((expr.clone(), rhs)),
332 }
333 })
334 .collect();
335
336 (TypeConstraints { inner: remaining }, new_known)
337 }
338}
339
/// A value discovered for a type variable once its constraint evaluated to a
/// literal, together with the context of that constraint.
#[derive(Clone, Debug)]
pub struct ConstraintReplacement {
    /// The discovered value. `update_type_level_value_constraints` only ever
    /// produces `KnownType::Integer` or `KnownType::Bool` here.
    pub val: KnownType,
    /// Context copied from the constraint that produced `val`.
    pub context: ConstraintContext,
}

/// Bookkeeping attached to a constraint, built for example by
/// `ConstraintExpr::with_context`.
#[derive(Clone, Debug)]
pub struct ConstraintContext {
    // NOTE(review): presumably the type variable that contains the
    // constrained one; confirm with callers of `with_context`.
    pub inside: TypeVarID,
    // NOTE(review): presumably the type variable whose value the constraint
    // supplies; confirm with callers of `with_context`.
    pub replaces: TypeVarID,
    /// Which language construct produced the constraint.
    pub source: ConstraintSource,
}