spade_hir_lowering/passes/
lower_lambda_defs.rs1use std::collections::HashMap;
2
3use spade_common::{
4 id_tracker::ExprIdTracker,
5 location_info::{Loc, WithLocation},
6 name::NameID,
7};
8use spade_diagnostics::{diag_anyhow, Diagnostic};
9use spade_hir::{
10 expression::CallKind, ArgumentList, ExprKind, Expression, Parameter, ParameterList, Pattern,
11 Statement, TypeParam, TypeSpec, Unit, UnitHead,
12};
13use spade_typeinference::{equation::KnownTypeVar, HasType, TypeState};
14
15use crate::error::Result;
16
17use super::pass::Pass;
18
19pub(crate) struct LambdaReplacement {
20 pub new_body: Loc<Expression>,
21 pub arguments: Vec<(Loc<Pattern>, KnownTypeVar)>,
22 pub captured_type_params: HashMap<NameID, NameID>,
23}
24
25impl LambdaReplacement {
26 fn replace_type_params(&self, old: &Vec<Loc<TypeParam>>) -> Vec<Loc<TypeParam>> {
27 old.clone()
28 .into_iter()
29 .map(|tp| {
30 let loc = tp.loc();
31 let TypeParam {
32 ident,
33 name_id,
34 trait_bounds,
35 meta,
36 } = tp.inner;
37 TypeParam {
38 name_id: self
39 .captured_type_params
40 .get(&name_id)
41 .cloned()
42 .unwrap_or(name_id),
43 ident,
44 trait_bounds,
45 meta,
46 }
47 .at_loc(&loc)
48 })
49 .collect::<Vec<_>>()
50 }
51
52 fn update_type_spec(&self, ts: Loc<TypeSpec>) -> Loc<TypeSpec> {
53 let mut new_ts = ts.clone();
54 for (from, to) in &self.captured_type_params {
55 new_ts = new_ts.map(|ty| {
56 ty.replace_in(
57 &TypeSpec::Generic(from.clone().at_loc(&ts)),
58 &TypeSpec::Generic(to.clone().at_loc(&ts)),
59 )
60 })
61 }
62 new_ts
63 }
64
65 pub fn replace_in(&self, old: Loc<Unit>, idtracker: &mut ExprIdTracker) -> Result<Loc<Unit>> {
66 let arg_bindings = self
67 .arguments
68 .iter()
69 .enumerate()
70 .map(|(i, (arg, _))| {
71 let (input, _) = old.inputs.get(1).ok_or_else(|| {
73 diag_anyhow!(
74 arg,
75 "Did not find any arguments to the generated lambda body"
76 )
77 })?;
78 Ok(Statement::binding(
79 arg.clone(),
80 None,
81 ExprKind::TupleIndex(
82 Box::new(
83 ExprKind::Identifier(input.clone().inner)
84 .with_id(idtracker.next())
85 .at_loc(arg),
86 ),
87 (i as u128).at_loc(arg),
88 )
89 .with_id(idtracker.next())
90 .at_loc(input),
91 )
92 .at_loc(arg))
93 })
94 .collect::<Result<Vec<_>>>()?;
95
96 let scope_type_params = self.replace_type_params(&old.head.scope_type_params);
97 let unit_type_params = self.replace_type_params(&old.head.unit_type_params);
98
99 let body = self.new_body.clone().map(|mut body| {
100 let block = body.assume_block_mut();
101
102 block.statements = arg_bindings
103 .clone()
104 .into_iter()
105 .chain(block.statements.clone())
106 .collect::<Vec<_>>();
107
108 body
109 });
110
111 let result = old.map_ref(move |unit| spade_hir::Unit {
112 body: body.clone(),
113 inputs: unit
114 .inputs
115 .iter()
116 .map(|(n, t)| (n.clone(), self.update_type_spec(t.clone())))
117 .collect(),
118 head: UnitHead {
119 scope_type_params: scope_type_params.clone(),
120 unit_type_params: unit_type_params.clone(),
121 inputs: ParameterList(
122 unit.head
123 .inputs
124 .0
125 .iter()
126 .cloned()
127 .map(|i| Parameter {
128 no_mangle: i.no_mangle,
129 name: i.name,
130 ty: self.update_type_spec(i.ty),
131 })
132 .collect(),
133 )
134 .at_loc(&unit.head.inputs),
135 ..unit.head.clone()
136 },
137 ..unit.clone()
138 });
139 Ok(result)
140 }
141}
142
143pub(crate) struct LowerLambdaDefs<'a> {
144 pub type_state: &'a TypeState,
145
146 pub replacements: &'a mut HashMap<NameID, LambdaReplacement>,
147}
148
149impl<'a> Pass for LowerLambdaDefs<'a> {
150 fn visit_expression(&mut self, expression: &mut Loc<Expression>) -> Result<()> {
151 if let ExprKind::LambdaDef {
152 lambda_unit,
153 lambda_type,
154 lambda_type_params: _,
155 captured_generic_params,
156 arguments,
157 body,
158 } = &expression.kind
159 {
160 let arguments = arguments
161 .iter()
162 .cloned()
163 .map(|arg| {
164 let ty = arg
165 .get_type(&self.type_state)
166 .resolve(&self.type_state)
167 .into_known(&self.type_state)
168 .ok_or_else(|| {
169 Diagnostic::error(&arg, "The type of this argument is not fully known")
170 .primary_label("Type is not fully known")
171 })?;
172 Ok((arg, ty))
173 })
174 .collect::<Result<Vec<_>>>()?;
175
176 self.replacements.insert(
177 lambda_unit.clone(),
178 LambdaReplacement {
179 new_body: body.as_ref().clone(),
180 arguments: arguments.clone(),
181 captured_type_params: captured_generic_params
182 .iter()
183 .map(|tp| (tp.name_in_lambda.clone(), tp.name_in_body.inner.clone()))
184 .collect(),
185 },
186 );
187
188 *expression = ExprKind::Call {
189 kind: CallKind::Function,
190 callee: lambda_type.clone().at_loc(expression),
191 args: ArgumentList::Positional(vec![]).at_loc(expression),
192 turbofish: None,
193 }
194 .with_id(expression.id)
195 .at_loc(expression);
196
197 Ok(())
198 } else {
199 Ok(())
200 }
201 }
202
203 fn visit_unit(&mut self, _unit: &mut Unit) -> Result<()> {
204 Ok(())
205 }
206}