spade_hir_lowering/linear_check/
mod.rs

1use num::ToPrimitive;
2use tracing::trace;
3
4use spade_common::{
5    location_info::{Loc, WithLocation},
6    name::{NameID, Path},
7};
8use spade_diagnostics::diagnostic::Subdiagnostic;
9use spade_diagnostics::{diag_bail, Diagnostic};
10use spade_hir::{
11    expression::{NamedArgument, UnaryOperator},
12    symbol_table::SymbolTable,
13    ArgumentList, Binding, ExprKind, Expression, PipelineRegMarkerExtra, Register, Statement,
14    TypeList, TypeSpec,
15};
16use spade_typeinference::TypeState;
17
18use self::linear_state::{is_linear, LinearState};
19use crate::error::Result;
20
21mod linear_state;
22
23pub struct LinearCtx<'a> {
24    pub type_state: &'a TypeState,
25    pub symtab: &'a SymbolTable,
26    pub types: &'a TypeList,
27}
28
29/// Checks for linear type errors in a function-like. Reports errors if an linear
30/// type is not used exactly once
31pub fn check_linear_types(
32    inputs: &[(Loc<NameID>, Loc<TypeSpec>)],
33    body: &Loc<Expression>,
34    type_state: &TypeState,
35    symtab: &SymbolTable,
36    types: &TypeList,
37) -> Result<()> {
38    let ctx = LinearCtx {
39        types,
40        symtab,
41        type_state,
42    };
43
44    let mut linear_state = LinearState::new();
45
46    for (name, _) in inputs {
47        linear_state.push_new_name(name, &ctx)
48    }
49
50    visit_expression(body, &mut linear_state, &ctx)?;
51
52    linear_state.consume_expression(body)?;
53
54    linear_state.check_unused().map_err(|(alias, witness)| {
55        let self_description = match &alias.inner {
56            linear_state::ItemReference::Name(n) => format!("{n}{}", witness.motivation()),
57            linear_state::ItemReference::Anonymous(_) => {
58                format!("This has a field {} that", witness.motivation())
59            }
60        };
61        Diagnostic::error(&alias, format!("{self_description} is unused"))
62            .primary_label(format!("{self_description} is unused"))
63            .note(format!(
64                "{self_description} is an inverted wire (`inv &`) which must be set"
65            ))
66    })?;
67
68    Ok(())
69}
70
71pub fn visit_statement(
72    stmt: &Loc<Statement>,
73    linear_state: &mut LinearState,
74    ctx: &LinearCtx,
75) -> Result<()> {
76    match &stmt.inner {
77        Statement::Error => {}
78        Statement::Binding(Binding {
79            pattern,
80            ty: _,
81            value,
82            wal_trace: _,
83        }) => {
84            visit_expression(value, linear_state, ctx)?;
85            linear_state.consume_expression(value)?;
86            linear_state.push_pattern(pattern, ctx)?
87        }
88        Statement::Expression(expr) => {
89            visit_expression(expr, linear_state, ctx)?;
90        }
91        Statement::Register(reg) => {
92            let Register {
93                pattern,
94                clock,
95                reset,
96                initial,
97                value,
98                value_type: _,
99                attributes: _,
100            } = &reg;
101
102            linear_state.push_pattern(pattern, ctx)?;
103
104            visit_expression(clock, linear_state, ctx)?;
105            if let Some((trig, val)) = &reset {
106                visit_expression(trig, linear_state, ctx)?;
107                visit_expression(val, linear_state, ctx)?;
108            }
109            initial
110                .as_ref()
111                .map(|i| visit_expression(i, linear_state, ctx))
112                .transpose()?;
113
114            visit_expression(value, linear_state, ctx)?;
115
116            linear_state.consume_expression(value)?;
117        }
118        Statement::Declaration(names) => {
119            for name in names {
120                linear_state.push_new_name(name, ctx)
121            }
122        }
123        Statement::PipelineRegMarker(cond) => match cond {
124            Some(PipelineRegMarkerExtra::Count {
125                count: _,
126                count_typeexpr_id: _,
127            }) => {}
128            Some(PipelineRegMarkerExtra::Condition(cond)) => {
129                visit_expression(cond, linear_state, ctx)?;
130            }
131            None => {}
132        },
133        Statement::Label(_) => {}
134        Statement::Assert(_) => {}
135        Statement::WalSuffixed { .. } => {}
136        Statement::Set { target, value } => {
137            visit_expression(target, linear_state, ctx)?;
138            visit_expression(value, linear_state, ctx)?;
139            linear_state.consume_expression(target)?;
140            linear_state.consume_expression(value)?;
141        }
142    }
143    Ok(())
144}
145
146#[tracing::instrument(level = "trace", skip_all)]
147fn visit_expression(
148    expr: &Loc<Expression>,
149    linear_state: &mut LinearState,
150    ctx: &LinearCtx,
151) -> Result<()> {
152    let produces_new_resource = match &expr.kind {
153        spade_hir::ExprKind::Error => true,
154        spade_hir::ExprKind::Identifier(_) => false,
155        spade_hir::ExprKind::IntLiteral(_, _) => true,
156        spade_hir::ExprKind::TypeLevelInteger(_) => true,
157        spade_hir::ExprKind::BoolLiteral(_) => true,
158        spade_hir::ExprKind::BitLiteral(_) => true,
159        spade_hir::ExprKind::TupleLiteral(_) => true,
160        spade_hir::ExprKind::ArrayLiteral(_) => true,
161        spade_hir::ExprKind::ArrayShorthandLiteral(_, _) => true,
162        spade_hir::ExprKind::CreatePorts => true,
163        spade_hir::ExprKind::Index(_, _) => true,
164        spade_hir::ExprKind::RangeIndex { .. } => true,
165        spade_hir::ExprKind::TupleIndex(_, _) => false,
166        spade_hir::ExprKind::FieldAccess(_, _) => false,
167        spade_hir::ExprKind::BinaryOperator(_, _, _) => true,
168        spade_hir::ExprKind::UnaryOperator(_, _) => true,
169        spade_hir::ExprKind::Match(_, _) => true,
170        spade_hir::ExprKind::Block(_) => true,
171        spade_hir::ExprKind::Call { .. } => true,
172        spade_hir::ExprKind::If(_, _, _) => true,
173        spade_hir::ExprKind::TypeLevelIf(_, _, _) => true,
174        spade_hir::ExprKind::StageValid | spade_hir::ExprKind::StageReady => true,
175        spade_hir::ExprKind::PipelineRef {
176            stage: _,
177            name: _,
178            declares_name: _,
179            depth_typeexpr_id: _,
180        } => false,
181        spade_hir::ExprKind::LambdaDef { .. } => diag_bail!(
182            expr,
183            "Lambda def should have been lowered to function by this point"
184        ),
185        spade_hir::ExprKind::MethodCall { .. } => diag_bail!(
186            expr,
187            "method call should have been lowered to function by this point"
188        ),
189        spade_hir::ExprKind::Null | ExprKind::StaticUnreachable(_) => false,
190    };
191
192    if produces_new_resource {
193        trace!("Pushing expression {}", expr.id.0);
194        linear_state.push_new_expression(&expr.map_ref(|e| e.id), ctx);
195    }
196
197    match &expr.kind {
198        spade_hir::ExprKind::Error => {}
199        spade_hir::ExprKind::Identifier(name) => {
200            linear_state.add_alias_name(expr.id.at_loc(expr), &name.clone().at_loc(expr))?
201        }
202        spade_hir::ExprKind::IntLiteral(_, _) => {}
203        spade_hir::ExprKind::TypeLevelInteger(_) => {}
204        spade_hir::ExprKind::BoolLiteral(_) => {}
205        spade_hir::ExprKind::BitLiteral(_) => {}
206        spade_hir::ExprKind::StageValid | spade_hir::ExprKind::StageReady => {}
207        spade_hir::ExprKind::TupleLiteral(inner) => {
208            for (i, expr) in inner.iter().enumerate() {
209                visit_expression(expr, linear_state, ctx)?;
210                trace!("visited tuple literal member {i}");
211                linear_state.consume_expression(expr)?;
212            }
213        }
214        spade_hir::ExprKind::ArrayLiteral(inner) => {
215            for expr in inner {
216                visit_expression(expr, linear_state, ctx)?;
217                trace!("Consuming array literal inner");
218                linear_state.consume_expression(expr)?;
219            }
220        }
221        spade_hir::ExprKind::ArrayShorthandLiteral(inner, _) => {
222            visit_expression(inner, linear_state, ctx)?;
223            // FIXME: should allow `[instance of ~&T; 0]` and `[instance of ~&T; 1]` here
224            // try to consume twice. if we get an error, add a note
225            linear_state.consume_expression(inner)?;
226            if let Err(mut diag) = linear_state.consume_expression(inner) {
227                diag.push_subdiagnostic(Subdiagnostic::span_note(
228                    expr,
229                    "The resource is used in this array initialization",
230                ));
231                return Err(diag);
232            }
233        }
234        spade_hir::ExprKind::CreatePorts => {}
235        spade_hir::ExprKind::Index(target, idx_expr) => {
236            visit_expression(target, linear_state, ctx)?;
237            visit_expression(idx_expr, linear_state, ctx)?;
238
239            if is_linear(
240                &ctx.type_state
241                    .concrete_type_of(target, ctx.symtab, ctx.types)?,
242            ) {
243                let idx = match &idx_expr.kind {
244                    ExprKind::IntLiteral(value, _) => value,
245                    _ => {
246                        return Err(Diagnostic::error(
247                            expr,
248                            "Array with mutable wires cannot be indexed by non-constant values",
249                        )
250                        .primary_label("Array with mutable wires indexed by non-constant")
251                        .secondary_label(idx_expr.loc(), "Expected constant"))
252                    }
253                };
254
255                let idx = idx.to_u128().ok_or_else(|| {
256                    Diagnostic::error(
257                        target.loc(),
258                        "Array indices > 2^64 are not allowed on mutable wires",
259                    )
260                })?;
261
262                // If the array has mutable wires, we need to guarantee statically that they are
263                // used exactly once. To do that, we need to ensure that the array is indexed by a
264                // statically known index. However, this check is only required if the array actually
265                // has linear type
266                linear_state.alias_array_member(
267                    expr.id.at_loc(expr),
268                    target.id,
269                    &idx.at_loc(idx_expr),
270                )?;
271            } else {
272                linear_state.consume_expression(target)?;
273            }
274
275            linear_state.consume_expression(idx_expr)?;
276        }
277        spade_hir::ExprKind::RangeIndex {
278            target,
279            start: _,
280            end: _,
281        } => {
282            visit_expression(target, linear_state, ctx)?;
283            // We don't track individual elements of arrays, so we'll have to consume the
284            // whole thing here
285            linear_state.consume_expression(target)?;
286        }
287        spade_hir::ExprKind::TupleIndex(base, idx) => {
288            visit_expression(base, linear_state, ctx)?;
289            linear_state.alias_tuple_member(expr.id.at_loc(expr), base.id, idx)?
290        }
291        spade_hir::ExprKind::FieldAccess(base, field) => {
292            visit_expression(base, linear_state, ctx)?;
293            linear_state.alias_struct_member(expr.id.at_loc(expr), base.id, field)?
294        }
295        spade_hir::ExprKind::BinaryOperator(lhs, _, rhs) => {
296            visit_expression(lhs, linear_state, ctx)?;
297            visit_expression(rhs, linear_state, ctx)?;
298            linear_state.consume_expression(lhs)?;
299            linear_state.consume_expression(rhs)?;
300        }
301        spade_hir::ExprKind::UnaryOperator(op, operand) => {
302            visit_expression(operand, linear_state, ctx)?;
303            match op.inner {
304                UnaryOperator::Sub
305                | UnaryOperator::Not
306                | UnaryOperator::BitwiseNot
307                | UnaryOperator::Reference => {
308                    linear_state.consume_expression(operand)?;
309                }
310                UnaryOperator::Dereference => {}
311            }
312        }
313        spade_hir::ExprKind::Match(cond, variants) => {
314            visit_expression(cond, linear_state, ctx)?;
315            for (pat, expr) in variants {
316                linear_state.push_pattern(pat, ctx)?;
317                visit_expression(expr, linear_state, ctx)?;
318            }
319        }
320        spade_hir::ExprKind::Block(b) => {
321            for statement in &b.statements {
322                visit_statement(statement, linear_state, ctx)?;
323            }
324            if let Some(result) = &b.result {
325                visit_expression(result, linear_state, ctx)?;
326                trace!("Consuming block {}", expr.id.0);
327                linear_state.consume_expression(result)?;
328            }
329        }
330        spade_hir::ExprKind::Call {
331            kind: _,
332            callee,
333            args: list,
334            turbofish: _,
335        } => {
336            // The read_mut_wire function is special and should not consume the port
337            // it is reading.
338            // FIXME: When spade is more generic and can handle the * operator
339            // doing more fancy things, we should consider getting rid of this function
340            let consume = ctx
341                .symtab
342                .try_lookup_final_id(
343                    &Path::from_strs(&["std", "ports", "read_mut_wire"]).nowhere(),
344                    &[],
345                )
346                .map(|n| n != callee.inner)
347                .unwrap_or(true);
348
349            match &list.inner {
350                ArgumentList::Named(args) => {
351                    for arg in args {
352                        match arg {
353                            NamedArgument::Full(_, expr) | NamedArgument::Short(_, expr) => {
354                                visit_expression(expr, linear_state, ctx)?;
355                                if consume {
356                                    linear_state.consume_expression(expr)?;
357                                }
358                            }
359                        }
360                    }
361                }
362                ArgumentList::Positional(args) => {
363                    for arg in args {
364                        visit_expression(arg, linear_state, ctx)?;
365                        if consume {
366                            linear_state.consume_expression(arg)?;
367                        }
368                    }
369                }
370            }
371        }
372        spade_hir::ExprKind::If(cond, on_true, on_false) => {
373            visit_expression(cond, linear_state, ctx)?;
374            visit_expression(on_true, linear_state, ctx)?;
375            visit_expression(on_false, linear_state, ctx)?;
376        }
377        spade_hir::ExprKind::PipelineRef {
378            stage: _,
379            name,
380            declares_name,
381            depth_typeexpr_id: _,
382        } => {
383            if *declares_name {
384                linear_state.push_new_name(name, ctx);
385            }
386            linear_state.add_alias_name(expr.id.at_loc(expr), &name.clone())?
387        }
388        spade_hir::ExprKind::TypeLevelIf(_, _, _) => {
389            diag_bail!(expr, "Type level if should have been lowered")
390        }
391        spade_hir::ExprKind::MethodCall { .. } => diag_bail!(
392            expr,
393            "method call should have been lowered to function by this point"
394        ),
395        spade_hir::ExprKind::LambdaDef { .. } => diag_bail!(
396            expr,
397            "lambda def should have been lowered to function by this point"
398        ),
399        spade_hir::ExprKind::Null { .. } => {
400            diag_bail!(expr, "Null expression created before linear check")
401        }
402        spade_hir::ExprKind::StaticUnreachable(_) => {}
403    }
404    Ok(())
405}