spade_ast_lowering/
type_level_if.rs

1use spade_common::location_info::Loc;
2use spade_common::location_info::WithLocation;
3use spade_common::name::Identifier;
4use spade_diagnostics::diag_bail;
5use spade_diagnostics::Diagnostic;
6use spade_hir::expression::CallKind;
7use spade_hir::symbol_table::Thing;
8use spade_hir::ArgumentList;
9use spade_hir::Binding;
10use spade_hir::Block;
11use spade_hir::ExecutableItem;
12use spade_hir::Expression;
13use spade_hir::PatternKind;
14use spade_hir::Statement;
15use spade_hir::TypeExpression;
16use spade_hir::UnitKind;
17use spade_hir::UnitName;
18use spade_hir::{ExprKind, Unit};
19
20use crate::Context;
21use crate::Result;
22
23// For pipelining reasons, if we have a unit like
24// {
25//     reg;
26//     gen if ... {
27//         result1
28//     } else {
29//         result2
30//     }
31// }
32// we want to convert it into
33// {
34//     gen if ... {
35//         reg;
36//         result1
37//     } else {
38//         reg;
39//         result2
40//     }
41// }
42// This performs that replacement
43pub fn absorb_statements(
44    body: &Loc<Expression>,
45    outer_statements: &Vec<Loc<Statement>>,
46    ctx: &mut Context,
47) -> Result<Loc<Expression>> {
48    body.try_map_ref(|expr| match &expr.kind {
49        ExprKind::TypeLevelIf(cond, on_true, on_false) => Ok(ExprKind::TypeLevelIf(
50            cond.clone(),
51            Box::new(absorb_statements(on_true, outer_statements, ctx)?),
52            Box::new(absorb_statements(on_false, outer_statements, ctx)?),
53        )
54        .with_id(ctx.idtracker.next())),
55        ExprKind::Block(block) => Ok(ExprKind::Block(Box::new(Block {
56            statements: outer_statements
57                .iter()
58                .chain(block.statements.iter())
59                .cloned()
60                .collect(),
61            result: block.result.clone(),
62        }))
63        .with_id(ctx.idtracker.next())),
64        ExprKind::Error => Ok(ExprKind::Error.with_id(ctx.idtracker.next())),
65        _ => Err(Diagnostic::bug(
66            body,
67            "The body of a gen if can only be a block or another gen if",
68        )
69        .primary_label(format!("Invalid body of gen if"))),
70    })
71}
72
73pub fn expand_type_level_if(mut unit: Loc<Unit>, ctx: &mut Context) -> Result<Loc<Unit>> {
74    let Ok(body) = unit.body.assume_block() else {
75        unit.body.kind = ExprKind::Error;
76        return Ok(unit);
77    };
78
79    let expand_body =
80        |new_body: &Loc<Expression>, name_suffix: &str, ctx: &mut Context| -> Result<_> {
81            let mut new_unit = unit.clone();
82            let absorbed = absorb_statements(&new_body, &body.statements, ctx)?;
83            new_unit.body = match &absorbed.kind {
84                ExprKind::TypeLevelIf(_, _, _) => {
85                    let loc = absorbed.loc();
86                    ExprKind::Block(Box::new(Block {
87                        statements: vec![],
88                        result: Some(absorbed),
89                    }))
90                    .with_id(ctx.idtracker.next())
91                    .at_loc(&loc)
92                }
93                ExprKind::Block(_) => absorbed,
94                ExprKind::Error => absorbed,
95                _ => diag_bail!(absorbed, "Non tlif or body"),
96            };
97
98            let new_name = unit
99                .name
100                .name_id()
101                .1
102                .clone()
103                .push_ident(Identifier(name_suffix.to_string()).nowhere());
104            let new_nameid = ctx
105                .symtab
106                .add_thing(new_name, Thing::Unit(new_unit.head.clone().at_loc(&unit)));
107            new_unit.name = UnitName::WithID(new_nameid.clone().at_loc(&unit.head.name));
108
109            let new_unit = expand_type_level_if(new_unit, ctx)?;
110            ctx.item_list.add_executable(
111                new_nameid.clone().at_loc(&unit.head.name),
112                ExecutableItem::Unit(new_unit),
113            )?;
114
115            Ok(new_nameid.at_loc(&unit.head.name))
116        };
117
118    let call_expanded = |expanded_name, ctx: &mut Context| {
119        let kind = match &unit.head.unit_kind.inner {
120            UnitKind::Function(_) => CallKind::Function,
121            UnitKind::Entity => CallKind::Entity(().nowhere()),
122            UnitKind::Pipeline {
123                depth,
124                depth_typeexpr_id: _,
125            } => CallKind::Pipeline {
126                inst_loc: ().nowhere(),
127                depth: depth.clone(),
128                depth_typeexpr_id: ctx.idtracker.next(),
129            },
130        };
131
132        let args = ArgumentList::Positional(
133            unit.inputs
134                .iter()
135                .map(|(name, _)| {
136                    ExprKind::Identifier(name.inner.clone())
137                        .with_id(ctx.idtracker.next())
138                        .at_loc(&name)
139                })
140                .collect(),
141        )
142        .at_loc(&unit.head.inputs);
143
144        let turbofish = if !unit.head.unit_type_params.is_empty() {
145            Some(
146                ArgumentList::Positional(
147                    unit.head
148                        .unit_type_params
149                        .iter()
150                        .map(|p| {
151                            TypeExpression::TypeSpec(spade_hir::TypeSpec::Generic(
152                                p.name_id.clone().at_loc(p),
153                            ))
154                            .at_loc(p)
155                        })
156                        .collect(),
157                )
158                .at_loc(&unit),
159            )
160        } else {
161            None
162        };
163
164        ExprKind::Call {
165            kind,
166            callee: expanded_name,
167            args,
168            turbofish,
169        }
170        .with_id(ctx.idtracker.next())
171        .at_loc(&unit.body)
172    };
173
174    match body.result.as_ref().map(|e| &e.kind) {
175        Some(ExprKind::TypeLevelIf(cond, on_true, on_false)) => {
176            let on_true = expand_body(&on_true, "T", ctx)?;
177            let on_false = expand_body(&on_false, "F", ctx)?;
178
179            let new_on_true = call_expanded(on_true, ctx);
180            let new_on_false = call_expanded(on_false, ctx);
181
182            let new_result =
183                ExprKind::TypeLevelIf(cond.clone(), Box::new(new_on_true), Box::new(new_on_false))
184                    .with_id(ctx.idtracker.next())
185                    .at_loc(&unit.body);
186
187            let result_name = ctx
188                .symtab
189                .add_local_variable(Identifier("result".to_string()).at_loc(&unit));
190
191            let result_binding = Statement::Binding(Binding {
192                pattern: PatternKind::Name {
193                    name: result_name.clone().at_loc(&unit),
194                    pre_declared: false,
195                }
196                .with_id(ctx.idtracker.next())
197                .at_loc(&unit),
198                ty: None,
199                value: new_result,
200                wal_trace: None,
201            })
202            .at_loc(&unit);
203
204            let pipeline_depth = match &unit.head.unit_kind.inner {
205                UnitKind::Function(_) => None,
206                UnitKind::Entity => None,
207                UnitKind::Pipeline {
208                    depth,
209                    depth_typeexpr_id: _,
210                } => Some(depth),
211            };
212            let pipeline_reg = pipeline_depth
213                .map(|depth| {
214                    vec![Statement::PipelineRegMarker(Some(
215                        spade_hir::PipelineRegMarkerExtra::Count {
216                            count: depth.clone(),
217                            count_typeexpr_id: ctx.idtracker.next(),
218                        },
219                    ))
220                    .at_loc(&depth)]
221                })
222                .unwrap_or_default();
223
224            unit.body = ExprKind::Block(Box::new(Block {
225                statements: vec![result_binding]
226                    .into_iter()
227                    .chain(pipeline_reg)
228                    .collect(),
229                result: Some(
230                    ExprKind::Identifier(result_name)
231                        .with_id(ctx.idtracker.next())
232                        .at_loc(&unit),
233                ),
234            }))
235            .with_id(ctx.idtracker.next())
236            .at_loc(&unit.body);
237
238            Ok(expand_type_level_if(unit, ctx)?)
239        }
240        _ => Ok(unit),
241    }
242}