// spade_ast_lowering/type_level_if.rs
1use spade_common::location_info::Loc;
2use spade_common::location_info::WithLocation;
3use spade_common::name::Identifier;
4use spade_diagnostics::diag_bail;
5use spade_diagnostics::Diagnostic;
6use spade_hir::expression::CallKind;
7use spade_hir::symbol_table::Thing;
8use spade_hir::ArgumentList;
9use spade_hir::Binding;
10use spade_hir::Block;
11use spade_hir::ExecutableItem;
12use spade_hir::Expression;
13use spade_hir::PatternKind;
14use spade_hir::Statement;
15use spade_hir::TypeExpression;
16use spade_hir::UnitKind;
17use spade_hir::UnitName;
18use spade_hir::{ExprKind, Unit};
19
20use crate::Context;
21use crate::Result;
22
23pub fn absorb_statements(
44 body: &Loc<Expression>,
45 outer_statements: &Vec<Loc<Statement>>,
46 ctx: &mut Context,
47) -> Result<Loc<Expression>> {
48 body.try_map_ref(|expr| match &expr.kind {
49 ExprKind::TypeLevelIf(cond, on_true, on_false) => Ok(ExprKind::TypeLevelIf(
50 cond.clone(),
51 Box::new(absorb_statements(on_true, outer_statements, ctx)?),
52 Box::new(absorb_statements(on_false, outer_statements, ctx)?),
53 )
54 .with_id(ctx.idtracker.next())),
55 ExprKind::Block(block) => Ok(ExprKind::Block(Box::new(Block {
56 statements: outer_statements
57 .iter()
58 .chain(block.statements.iter())
59 .cloned()
60 .collect(),
61 result: block.result.clone(),
62 }))
63 .with_id(ctx.idtracker.next())),
64 ExprKind::Error => Ok(ExprKind::Error.with_id(ctx.idtracker.next())),
65 _ => Err(Diagnostic::bug(
66 body,
67 "The body of a gen if can only be a block or another gen if",
68 )
69 .primary_label(format!("Invalid body of gen if"))),
70 })
71}
72
73pub fn expand_type_level_if(mut unit: Loc<Unit>, ctx: &mut Context) -> Result<Loc<Unit>> {
74 let Ok(body) = unit.body.assume_block() else {
75 unit.body.kind = ExprKind::Error;
76 return Ok(unit);
77 };
78
79 let expand_body =
80 |new_body: &Loc<Expression>, name_suffix: &str, ctx: &mut Context| -> Result<_> {
81 let mut new_unit = unit.clone();
82 let absorbed = absorb_statements(&new_body, &body.statements, ctx)?;
83 new_unit.body = match &absorbed.kind {
84 ExprKind::TypeLevelIf(_, _, _) => {
85 let loc = absorbed.loc();
86 ExprKind::Block(Box::new(Block {
87 statements: vec![],
88 result: Some(absorbed),
89 }))
90 .with_id(ctx.idtracker.next())
91 .at_loc(&loc)
92 }
93 ExprKind::Block(_) => absorbed,
94 ExprKind::Error => absorbed,
95 _ => diag_bail!(absorbed, "Non tlif or body"),
96 };
97
98 let new_name = unit
99 .name
100 .name_id()
101 .1
102 .clone()
103 .push_ident(Identifier(name_suffix.to_string()).nowhere());
104 let new_nameid = ctx
105 .symtab
106 .add_thing(new_name, Thing::Unit(new_unit.head.clone().at_loc(&unit)));
107 new_unit.name = UnitName::WithID(new_nameid.clone().at_loc(&unit.head.name));
108
109 let new_unit = expand_type_level_if(new_unit, ctx)?;
110 ctx.item_list.add_executable(
111 new_nameid.clone().at_loc(&unit.head.name),
112 ExecutableItem::Unit(new_unit),
113 )?;
114
115 Ok(new_nameid.at_loc(&unit.head.name))
116 };
117
118 let call_expanded = |expanded_name, ctx: &mut Context| {
119 let kind = match &unit.head.unit_kind.inner {
120 UnitKind::Function(_) => CallKind::Function,
121 UnitKind::Entity => CallKind::Entity(().nowhere()),
122 UnitKind::Pipeline {
123 depth,
124 depth_typeexpr_id: _,
125 } => CallKind::Pipeline {
126 inst_loc: ().nowhere(),
127 depth: depth.clone(),
128 depth_typeexpr_id: ctx.idtracker.next(),
129 },
130 };
131
132 let args = ArgumentList::Positional(
133 unit.inputs
134 .iter()
135 .map(|(name, _)| {
136 ExprKind::Identifier(name.inner.clone())
137 .with_id(ctx.idtracker.next())
138 .at_loc(&name)
139 })
140 .collect(),
141 )
142 .at_loc(&unit.head.inputs);
143
144 let turbofish = if !unit.head.unit_type_params.is_empty() {
145 Some(
146 ArgumentList::Positional(
147 unit.head
148 .unit_type_params
149 .iter()
150 .map(|p| {
151 TypeExpression::TypeSpec(spade_hir::TypeSpec::Generic(
152 p.name_id.clone().at_loc(p),
153 ))
154 .at_loc(p)
155 })
156 .collect(),
157 )
158 .at_loc(&unit),
159 )
160 } else {
161 None
162 };
163
164 ExprKind::Call {
165 kind,
166 callee: expanded_name,
167 args,
168 turbofish,
169 }
170 .with_id(ctx.idtracker.next())
171 .at_loc(&unit.body)
172 };
173
174 match body.result.as_ref().map(|e| &e.kind) {
175 Some(ExprKind::TypeLevelIf(cond, on_true, on_false)) => {
176 let on_true = expand_body(&on_true, "T", ctx)?;
177 let on_false = expand_body(&on_false, "F", ctx)?;
178
179 let new_on_true = call_expanded(on_true, ctx);
180 let new_on_false = call_expanded(on_false, ctx);
181
182 let new_result =
183 ExprKind::TypeLevelIf(cond.clone(), Box::new(new_on_true), Box::new(new_on_false))
184 .with_id(ctx.idtracker.next())
185 .at_loc(&unit.body);
186
187 let result_name = ctx
188 .symtab
189 .add_local_variable(Identifier("result".to_string()).at_loc(&unit));
190
191 let result_binding = Statement::Binding(Binding {
192 pattern: PatternKind::Name {
193 name: result_name.clone().at_loc(&unit),
194 pre_declared: false,
195 }
196 .with_id(ctx.idtracker.next())
197 .at_loc(&unit),
198 ty: None,
199 value: new_result,
200 wal_trace: None,
201 })
202 .at_loc(&unit);
203
204 let pipeline_depth = match &unit.head.unit_kind.inner {
205 UnitKind::Function(_) => None,
206 UnitKind::Entity => None,
207 UnitKind::Pipeline {
208 depth,
209 depth_typeexpr_id: _,
210 } => Some(depth),
211 };
212 let pipeline_reg = pipeline_depth
213 .map(|depth| {
214 vec![Statement::PipelineRegMarker(Some(
215 spade_hir::PipelineRegMarkerExtra::Count {
216 count: depth.clone(),
217 count_typeexpr_id: ctx.idtracker.next(),
218 },
219 ))
220 .at_loc(&depth)]
221 })
222 .unwrap_or_default();
223
224 unit.body = ExprKind::Block(Box::new(Block {
225 statements: vec![result_binding]
226 .into_iter()
227 .chain(pipeline_reg)
228 .collect(),
229 result: Some(
230 ExprKind::Identifier(result_name)
231 .with_id(ctx.idtracker.next())
232 .at_loc(&unit),
233 ),
234 }))
235 .with_id(ctx.idtracker.next())
236 .at_loc(&unit.body);
237
238 Ok(expand_type_level_if(unit, ctx)?)
239 }
240 _ => Ok(unit),
241 }
242}