1use num::ToPrimitive;
2use tracing::trace;
3
4use spade_common::{
5 location_info::{Loc, WithLocation},
6 name::{NameID, Path},
7};
8use spade_diagnostics::diagnostic::Subdiagnostic;
9use spade_diagnostics::{diag_bail, Diagnostic};
10use spade_hir::{
11 expression::{NamedArgument, UnaryOperator},
12 symbol_table::SymbolTable,
13 ArgumentList, Binding, ExprKind, Expression, PipelineRegMarkerExtra, Register, Statement,
14 TypeList, TypeSpec,
15};
16use spade_typeinference::TypeState;
17
18use self::linear_state::{is_linear, LinearState};
19use crate::error::Result;
20
21mod linear_state;
22
23pub struct LinearCtx<'a> {
24 pub type_state: &'a TypeState,
25 pub symtab: &'a SymbolTable,
26 pub types: &'a TypeList,
27}
28
29pub fn check_linear_types(
32 inputs: &[(Loc<NameID>, Loc<TypeSpec>)],
33 body: &Loc<Expression>,
34 type_state: &TypeState,
35 symtab: &SymbolTable,
36 types: &TypeList,
37) -> Result<()> {
38 let ctx = LinearCtx {
39 types,
40 symtab,
41 type_state,
42 };
43
44 let mut linear_state = LinearState::new();
45
46 for (name, _) in inputs {
47 linear_state.push_new_name(name, &ctx)
48 }
49
50 visit_expression(body, &mut linear_state, &ctx)?;
51
52 linear_state.consume_expression(body)?;
53
54 linear_state.check_unused().map_err(|(alias, witness)| {
55 let self_description = match &alias.inner {
56 linear_state::ItemReference::Name(n) => format!("{n}{}", witness.motivation()),
57 linear_state::ItemReference::Anonymous(_) => {
58 format!("This has a field {} that", witness.motivation())
59 }
60 };
61 Diagnostic::error(&alias, format!("{self_description} is unused"))
62 .primary_label(format!("{self_description} is unused"))
63 .note(format!(
64 "{self_description} is an inverted wire (`inv &`) which must be set"
65 ))
66 })?;
67
68 Ok(())
69}
70
71pub fn visit_statement(
72 stmt: &Loc<Statement>,
73 linear_state: &mut LinearState,
74 ctx: &LinearCtx,
75) -> Result<()> {
76 match &stmt.inner {
77 Statement::Error => {}
78 Statement::Binding(Binding {
79 pattern,
80 ty: _,
81 value,
82 wal_trace: _,
83 }) => {
84 visit_expression(value, linear_state, ctx)?;
85 linear_state.consume_expression(value)?;
86 linear_state.push_pattern(pattern, ctx)?
87 }
88 Statement::Expression(expr) => {
89 visit_expression(expr, linear_state, ctx)?;
90 }
91 Statement::Register(reg) => {
92 let Register {
93 pattern,
94 clock,
95 reset,
96 initial,
97 value,
98 value_type: _,
99 attributes: _,
100 } = ®
101
102 linear_state.push_pattern(pattern, ctx)?;
103
104 visit_expression(clock, linear_state, ctx)?;
105 if let Some((trig, val)) = &reset {
106 visit_expression(trig, linear_state, ctx)?;
107 visit_expression(val, linear_state, ctx)?;
108 }
109 initial
110 .as_ref()
111 .map(|i| visit_expression(i, linear_state, ctx))
112 .transpose()?;
113
114 visit_expression(value, linear_state, ctx)?;
115
116 linear_state.consume_expression(value)?;
117 }
118 Statement::Declaration(names) => {
119 for name in names {
120 linear_state.push_new_name(name, ctx)
121 }
122 }
123 Statement::PipelineRegMarker(cond) => match cond {
124 Some(PipelineRegMarkerExtra::Count {
125 count: _,
126 count_typeexpr_id: _,
127 }) => {}
128 Some(PipelineRegMarkerExtra::Condition(cond)) => {
129 visit_expression(cond, linear_state, ctx)?;
130 }
131 None => {}
132 },
133 Statement::Label(_) => {}
134 Statement::Assert(_) => {}
135 Statement::WalSuffixed { .. } => {}
136 Statement::Set { target, value } => {
137 visit_expression(target, linear_state, ctx)?;
138 visit_expression(value, linear_state, ctx)?;
139 linear_state.consume_expression(target)?;
140 linear_state.consume_expression(value)?;
141 }
142 }
143 Ok(())
144}
145
146#[tracing::instrument(level = "trace", skip_all)]
147fn visit_expression(
148 expr: &Loc<Expression>,
149 linear_state: &mut LinearState,
150 ctx: &LinearCtx,
151) -> Result<()> {
152 let produces_new_resource = match &expr.kind {
153 spade_hir::ExprKind::Error => true,
154 spade_hir::ExprKind::Identifier(_) => false,
155 spade_hir::ExprKind::IntLiteral(_, _) => true,
156 spade_hir::ExprKind::TypeLevelInteger(_) => true,
157 spade_hir::ExprKind::BoolLiteral(_) => true,
158 spade_hir::ExprKind::BitLiteral(_) => true,
159 spade_hir::ExprKind::TupleLiteral(_) => true,
160 spade_hir::ExprKind::ArrayLiteral(_) => true,
161 spade_hir::ExprKind::ArrayShorthandLiteral(_, _) => true,
162 spade_hir::ExprKind::CreatePorts => true,
163 spade_hir::ExprKind::Index(_, _) => true,
164 spade_hir::ExprKind::RangeIndex { .. } => true,
165 spade_hir::ExprKind::TupleIndex(_, _) => false,
166 spade_hir::ExprKind::FieldAccess(_, _) => false,
167 spade_hir::ExprKind::BinaryOperator(_, _, _) => true,
168 spade_hir::ExprKind::UnaryOperator(_, _) => true,
169 spade_hir::ExprKind::Match(_, _) => true,
170 spade_hir::ExprKind::Block(_) => true,
171 spade_hir::ExprKind::Call { .. } => true,
172 spade_hir::ExprKind::If(_, _, _) => true,
173 spade_hir::ExprKind::TypeLevelIf(_, _, _) => true,
174 spade_hir::ExprKind::StageValid | spade_hir::ExprKind::StageReady => true,
175 spade_hir::ExprKind::PipelineRef {
176 stage: _,
177 name: _,
178 declares_name: _,
179 depth_typeexpr_id: _,
180 } => false,
181 spade_hir::ExprKind::LambdaDef { .. } => diag_bail!(
182 expr,
183 "Lambda def should have been lowered to function by this point"
184 ),
185 spade_hir::ExprKind::MethodCall { .. } => diag_bail!(
186 expr,
187 "method call should have been lowered to function by this point"
188 ),
189 spade_hir::ExprKind::Null | ExprKind::StaticUnreachable(_) => false,
190 };
191
192 if produces_new_resource {
193 trace!("Pushing expression {}", expr.id.0);
194 linear_state.push_new_expression(&expr.map_ref(|e| e.id), ctx);
195 }
196
197 match &expr.kind {
198 spade_hir::ExprKind::Error => {}
199 spade_hir::ExprKind::Identifier(name) => {
200 linear_state.add_alias_name(expr.id.at_loc(expr), &name.clone().at_loc(expr))?
201 }
202 spade_hir::ExprKind::IntLiteral(_, _) => {}
203 spade_hir::ExprKind::TypeLevelInteger(_) => {}
204 spade_hir::ExprKind::BoolLiteral(_) => {}
205 spade_hir::ExprKind::BitLiteral(_) => {}
206 spade_hir::ExprKind::StageValid | spade_hir::ExprKind::StageReady => {}
207 spade_hir::ExprKind::TupleLiteral(inner) => {
208 for (i, expr) in inner.iter().enumerate() {
209 visit_expression(expr, linear_state, ctx)?;
210 trace!("visited tuple literal member {i}");
211 linear_state.consume_expression(expr)?;
212 }
213 }
214 spade_hir::ExprKind::ArrayLiteral(inner) => {
215 for expr in inner {
216 visit_expression(expr, linear_state, ctx)?;
217 trace!("Consuming array literal inner");
218 linear_state.consume_expression(expr)?;
219 }
220 }
221 spade_hir::ExprKind::ArrayShorthandLiteral(inner, _) => {
222 visit_expression(inner, linear_state, ctx)?;
223 linear_state.consume_expression(inner)?;
226 if let Err(mut diag) = linear_state.consume_expression(inner) {
227 diag.push_subdiagnostic(Subdiagnostic::span_note(
228 expr,
229 "The resource is used in this array initialization",
230 ));
231 return Err(diag);
232 }
233 }
234 spade_hir::ExprKind::CreatePorts => {}
235 spade_hir::ExprKind::Index(target, idx_expr) => {
236 visit_expression(target, linear_state, ctx)?;
237 visit_expression(idx_expr, linear_state, ctx)?;
238
239 if is_linear(
240 &ctx.type_state
241 .concrete_type_of(target, ctx.symtab, ctx.types)?,
242 ) {
243 let idx = match &idx_expr.kind {
244 ExprKind::IntLiteral(value, _) => value,
245 _ => {
246 return Err(Diagnostic::error(
247 expr,
248 "Array with mutable wires cannot be indexed by non-constant values",
249 )
250 .primary_label("Array with mutable wires indexed by non-constant")
251 .secondary_label(idx_expr.loc(), "Expected constant"))
252 }
253 };
254
255 let idx = idx.to_u128().ok_or_else(|| {
256 Diagnostic::error(
257 target.loc(),
258 "Array indices > 2^64 are not allowed on mutable wires",
259 )
260 })?;
261
262 linear_state.alias_array_member(
267 expr.id.at_loc(expr),
268 target.id,
269 &idx.at_loc(idx_expr),
270 )?;
271 } else {
272 linear_state.consume_expression(target)?;
273 }
274
275 linear_state.consume_expression(idx_expr)?;
276 }
277 spade_hir::ExprKind::RangeIndex {
278 target,
279 start: _,
280 end: _,
281 } => {
282 visit_expression(target, linear_state, ctx)?;
283 linear_state.consume_expression(target)?;
286 }
287 spade_hir::ExprKind::TupleIndex(base, idx) => {
288 visit_expression(base, linear_state, ctx)?;
289 linear_state.alias_tuple_member(expr.id.at_loc(expr), base.id, idx)?
290 }
291 spade_hir::ExprKind::FieldAccess(base, field) => {
292 visit_expression(base, linear_state, ctx)?;
293 linear_state.alias_struct_member(expr.id.at_loc(expr), base.id, field)?
294 }
295 spade_hir::ExprKind::BinaryOperator(lhs, _, rhs) => {
296 visit_expression(lhs, linear_state, ctx)?;
297 visit_expression(rhs, linear_state, ctx)?;
298 linear_state.consume_expression(lhs)?;
299 linear_state.consume_expression(rhs)?;
300 }
301 spade_hir::ExprKind::UnaryOperator(op, operand) => {
302 visit_expression(operand, linear_state, ctx)?;
303 match op.inner {
304 UnaryOperator::Sub
305 | UnaryOperator::Not
306 | UnaryOperator::BitwiseNot
307 | UnaryOperator::Reference => {
308 linear_state.consume_expression(operand)?;
309 }
310 UnaryOperator::Dereference => {}
311 }
312 }
313 spade_hir::ExprKind::Match(cond, variants) => {
314 visit_expression(cond, linear_state, ctx)?;
315 for (pat, expr) in variants {
316 linear_state.push_pattern(pat, ctx)?;
317 visit_expression(expr, linear_state, ctx)?;
318 }
319 }
320 spade_hir::ExprKind::Block(b) => {
321 for statement in &b.statements {
322 visit_statement(statement, linear_state, ctx)?;
323 }
324 if let Some(result) = &b.result {
325 visit_expression(result, linear_state, ctx)?;
326 trace!("Consuming block {}", expr.id.0);
327 linear_state.consume_expression(result)?;
328 }
329 }
330 spade_hir::ExprKind::Call {
331 kind: _,
332 callee,
333 args: list,
334 turbofish: _,
335 } => {
336 let consume = ctx
341 .symtab
342 .try_lookup_final_id(
343 &Path::from_strs(&["std", "ports", "read_mut_wire"]).nowhere(),
344 &[],
345 )
346 .map(|n| n != callee.inner)
347 .unwrap_or(true);
348
349 match &list.inner {
350 ArgumentList::Named(args) => {
351 for arg in args {
352 match arg {
353 NamedArgument::Full(_, expr) | NamedArgument::Short(_, expr) => {
354 visit_expression(expr, linear_state, ctx)?;
355 if consume {
356 linear_state.consume_expression(expr)?;
357 }
358 }
359 }
360 }
361 }
362 ArgumentList::Positional(args) => {
363 for arg in args {
364 visit_expression(arg, linear_state, ctx)?;
365 if consume {
366 linear_state.consume_expression(arg)?;
367 }
368 }
369 }
370 }
371 }
372 spade_hir::ExprKind::If(cond, on_true, on_false) => {
373 visit_expression(cond, linear_state, ctx)?;
374 visit_expression(on_true, linear_state, ctx)?;
375 visit_expression(on_false, linear_state, ctx)?;
376 }
377 spade_hir::ExprKind::PipelineRef {
378 stage: _,
379 name,
380 declares_name,
381 depth_typeexpr_id: _,
382 } => {
383 if *declares_name {
384 linear_state.push_new_name(name, ctx);
385 }
386 linear_state.add_alias_name(expr.id.at_loc(expr), &name.clone())?
387 }
388 spade_hir::ExprKind::TypeLevelIf(_, _, _) => {
389 diag_bail!(expr, "Type level if should have been lowered")
390 }
391 spade_hir::ExprKind::MethodCall { .. } => diag_bail!(
392 expr,
393 "method call should have been lowered to function by this point"
394 ),
395 spade_hir::ExprKind::LambdaDef { .. } => diag_bail!(
396 expr,
397 "lambda def should have been lowered to function by this point"
398 ),
399 spade_hir::ExprKind::Null { .. } => {
400 diag_bail!(expr, "Null expression created before linear check")
401 }
402 spade_hir::ExprKind::StaticUnreachable(_) => {}
403 }
404 Ok(())
405}