spade_hir_lowering/passes/pass.rs

1use spade_common::location_info::Loc;
2use spade_hir::{Binding, ExprKind, Expression, PipelineRegMarkerExtra, Register, Statement, Unit};
3
4use crate::Result;
5
6pub trait Pass {
7    fn visit_expression(&mut self, expression: &mut Loc<Expression>) -> Result<()>;
8    /// Visit a statement, transforming it into a list of new statements which replace it. If the
9    /// statement should be replaced Ok(Some(new...)) should be returned, if it should be kept,
10    /// Ok(None) should be returned
11    fn visit_statement(
12        &mut self,
13        _statement: &Loc<Statement>,
14    ) -> Result<Option<Vec<Loc<Statement>>>> {
15        Ok(None)
16    }
17    /// Perform transformations on the unit. This should not transform the body of the unit, that
18    /// is handled by `visit_expression`
19    fn visit_unit(&mut self, unit: &mut Unit) -> Result<()>;
20}
21
22pub trait Passable {
23    /// Applies the pass to this HIR node. Children are visited before
24    /// parents. Statements are visited in the order that they are defined
25    fn apply(&mut self, pass: &mut dyn Pass) -> Result<()>;
26}
27
28impl Passable for Loc<Expression> {
29    fn apply(&mut self, pass: &mut dyn Pass) -> Result<()> {
30        macro_rules! subnodes {
31            ($($node:expr),*) => {
32                {$($node.apply(pass)?;)*}
33            };
34        }
35
36        match &mut self.inner.kind {
37            ExprKind::Error => {}
38            ExprKind::Identifier(_) => {}
39            ExprKind::IntLiteral(_, _) => {}
40            ExprKind::TypeLevelInteger(_) => {}
41            ExprKind::BoolLiteral(_) => {}
42            ExprKind::BitLiteral(_) => {}
43            ExprKind::CreatePorts => {}
44            ExprKind::StageReady | ExprKind::StageValid => {}
45            ExprKind::TupleLiteral(inner) => {
46                for i in inner {
47                    i.apply(pass)?
48                }
49            }
50            ExprKind::ArrayLiteral(inner) => {
51                for i in inner {
52                    i.apply(pass)?
53                }
54            }
55            ExprKind::ArrayShorthandLiteral(inner, _) => {
56                inner.apply(pass)?;
57            }
58            ExprKind::Index(lhs, rhs) => {
59                subnodes!(lhs, rhs)
60            }
61            ExprKind::RangeIndex {
62                target,
63                start: _,
64                end: _,
65            } => {
66                subnodes!(target)
67            }
68            ExprKind::TupleIndex(lhs, _) => subnodes!(lhs),
69            ExprKind::FieldAccess(lhs, _) => subnodes!(lhs),
70            ExprKind::MethodCall {
71                target: self_,
72                name: _,
73                args,
74                call_kind: _,
75                turbofish: _,
76            } => {
77                subnodes!(self_);
78                for arg in args.expressions_mut() {
79                    arg.apply(pass)?;
80                }
81            }
82            ExprKind::Call {
83                kind: _,
84                callee: _,
85                args,
86                turbofish: _,
87            } => {
88                for arg in args.expressions_mut() {
89                    arg.apply(pass)?;
90                }
91            }
92            ExprKind::BinaryOperator(lhs, _, rhs) => subnodes!(lhs, rhs),
93            ExprKind::UnaryOperator(_, operand) => subnodes!(operand),
94            ExprKind::Match(cond, branches) => {
95                cond.apply(pass)?;
96                for (_, branch) in branches {
97                    branch.apply(pass)?;
98                }
99            }
100            ExprKind::Block(block) => {
101                block.statements = block
102                    .statements
103                    .iter()
104                    .map(|stmt| match pass.visit_statement(stmt)? {
105                        Some(new) => Ok(new),
106                        None => Ok(vec![stmt.clone()]),
107                    })
108                    .collect::<Result<Vec<_>>>()?
109                    .into_iter()
110                    .flatten()
111                    .collect();
112
113                for statement in &mut block.statements {
114                    match &mut statement.inner {
115                        Statement::Error => {}
116                        Statement::Binding(Binding {
117                            pattern: _,
118                            ty: _,
119                            value,
120                            wal_trace: _,
121                        }) => value.apply(pass)?,
122                        Statement::Expression(expr) => expr.apply(pass)?,
123                        Statement::Register(reg) => {
124                            let Register {
125                                pattern: _,
126                                clock,
127                                reset,
128                                initial,
129                                value,
130                                value_type: _,
131                                attributes: _,
132                            } = reg;
133
134                            match reset {
135                                Some((trig, val)) => subnodes!(trig, val),
136                                None => {}
137                            }
138
139                            match initial {
140                                Some(initial) => subnodes!(initial),
141                                None => {}
142                            }
143
144                            subnodes!(clock, value);
145                        }
146                        Statement::Declaration(_) => {}
147                        Statement::PipelineRegMarker(extra) => match extra {
148                            Some(PipelineRegMarkerExtra::Condition(cond)) => {
149                                cond.apply(pass)?;
150                            }
151                            Some(PipelineRegMarkerExtra::Count {
152                                count: _,
153                                count_typeexpr_id: _,
154                            }) => {}
155                            None => {}
156                        },
157                        Statement::Label(_) => {}
158                        Statement::WalSuffixed {
159                            suffix: _,
160                            target: _,
161                        } => {}
162                        Statement::Assert(expr) => expr.apply(pass)?,
163                        Statement::Set { target, value } => subnodes!(target, value),
164                    }
165                }
166
167                if let Some(result) = &mut block.result {
168                    result.apply(pass)?;
169                }
170            }
171            ExprKind::LambdaDef {
172                arguments: _,
173                body,
174                lambda_type: _,
175                lambda_type_params: _,
176                captured_generic_params: _,
177                lambda_unit: _,
178            } => {
179                subnodes!(body)
180            }
181            ExprKind::If(cond, on_true, on_false) => subnodes!(cond, on_true, on_false),
182            ExprKind::TypeLevelIf(_cond, on_true, on_false) => subnodes!(on_true, on_false),
183            ExprKind::PipelineRef {
184                stage: _,
185                name: _,
186                declares_name: _,
187                depth_typeexpr_id: _,
188            } => {}
189            ExprKind::Null | ExprKind::StaticUnreachable(_) => {}
190        };
191
192        pass.visit_expression(self)
193    }
194}
195
196impl Passable for Unit {
197    fn apply(&mut self, pass: &mut dyn Pass) -> Result<()> {
198        pass.visit_unit(self)?;
199        self.body.apply(pass)?;
200        Ok(())
201    }
202}