spade_hir_lowering/passes/
lower_lambda_defs.rs

1use std::collections::HashMap;
2
3use spade_common::{
4    id_tracker::ExprIdTracker,
5    location_info::{Loc, WithLocation},
6    name::NameID,
7};
8use spade_diagnostics::{diag_anyhow, Diagnostic};
9use spade_hir::{
10    expression::CallKind, ArgumentList, ExprKind, Expression, Parameter, ParameterList, Pattern,
11    Statement, TypeParam, TypeSpec, Unit, UnitHead,
12};
13use spade_typeinference::{equation::KnownTypeVar, HasType, TypeState};
14
15use crate::error::Result;
16
17use super::pass::Pass;
18
19pub(crate) struct LambdaReplacement {
20    pub new_body: Loc<Expression>,
21    pub arguments: Vec<(Loc<Pattern>, KnownTypeVar)>,
22    pub captured_type_params: HashMap<NameID, NameID>,
23}
24
25impl LambdaReplacement {
26    fn replace_type_params(&self, old: &Vec<Loc<TypeParam>>) -> Vec<Loc<TypeParam>> {
27        old.clone()
28            .into_iter()
29            .map(|tp| {
30                let loc = tp.loc();
31                let TypeParam {
32                    ident,
33                    name_id,
34                    trait_bounds,
35                    meta,
36                } = tp.inner;
37                TypeParam {
38                    name_id: self
39                        .captured_type_params
40                        .get(&name_id)
41                        .cloned()
42                        .unwrap_or(name_id),
43                    ident,
44                    trait_bounds,
45                    meta,
46                }
47                .at_loc(&loc)
48            })
49            .collect::<Vec<_>>()
50    }
51
52    fn update_type_spec(&self, ts: Loc<TypeSpec>) -> Loc<TypeSpec> {
53        let mut new_ts = ts.clone();
54        for (from, to) in &self.captured_type_params {
55            new_ts = new_ts.map(|ty| {
56                ty.replace_in(
57                    &TypeSpec::Generic(from.clone().at_loc(&ts)),
58                    &TypeSpec::Generic(to.clone().at_loc(&ts)),
59                )
60            })
61        }
62        new_ts
63    }
64
65    pub fn replace_in(&self, old: Loc<Unit>, idtracker: &mut ExprIdTracker) -> Result<Loc<Unit>> {
66        let arg_bindings = self
67            .arguments
68            .iter()
69            .enumerate()
70            .map(|(i, (arg, _))| {
71                // .1, .0 is self
72                let (input, _) = old.inputs.get(1).ok_or_else(|| {
73                    diag_anyhow!(
74                        arg,
75                        "Did not find any arguments to the generated lambda body"
76                    )
77                })?;
78                Ok(Statement::binding(
79                    arg.clone(),
80                    None,
81                    ExprKind::TupleIndex(
82                        Box::new(
83                            ExprKind::Identifier(input.clone().inner)
84                                .with_id(idtracker.next())
85                                .at_loc(arg),
86                        ),
87                        (i as u128).at_loc(arg),
88                    )
89                    .with_id(idtracker.next())
90                    .at_loc(input),
91                )
92                .at_loc(arg))
93            })
94            .collect::<Result<Vec<_>>>()?;
95
96        let scope_type_params = self.replace_type_params(&old.head.scope_type_params);
97        let unit_type_params = self.replace_type_params(&old.head.unit_type_params);
98
99        let body = self.new_body.clone().map(|mut body| {
100            let block = body.assume_block_mut();
101
102            block.statements = arg_bindings
103                .clone()
104                .into_iter()
105                .chain(block.statements.clone())
106                .collect::<Vec<_>>();
107
108            body
109        });
110
111        let result = old.map_ref(move |unit| spade_hir::Unit {
112            body: body.clone(),
113            inputs: unit
114                .inputs
115                .iter()
116                .map(|(n, t)| (n.clone(), self.update_type_spec(t.clone())))
117                .collect(),
118            head: UnitHead {
119                scope_type_params: scope_type_params.clone(),
120                unit_type_params: unit_type_params.clone(),
121                inputs: ParameterList(
122                    unit.head
123                        .inputs
124                        .0
125                        .iter()
126                        .cloned()
127                        .map(|i| Parameter {
128                            no_mangle: i.no_mangle,
129                            name: i.name,
130                            ty: self.update_type_spec(i.ty),
131                        })
132                        .collect(),
133                )
134                .at_loc(&unit.head.inputs),
135                ..unit.head.clone()
136            },
137            ..unit.clone()
138        });
139        Ok(result)
140    }
141}
142
143pub(crate) struct LowerLambdaDefs<'a> {
144    pub type_state: &'a TypeState,
145
146    pub replacements: &'a mut HashMap<NameID, LambdaReplacement>,
147}
148
149impl<'a> Pass for LowerLambdaDefs<'a> {
150    fn visit_expression(&mut self, expression: &mut Loc<Expression>) -> Result<()> {
151        if let ExprKind::LambdaDef {
152            lambda_unit,
153            lambda_type,
154            lambda_type_params: _,
155            captured_generic_params,
156            arguments,
157            body,
158        } = &expression.kind
159        {
160            let arguments = arguments
161                .iter()
162                .cloned()
163                .map(|arg| {
164                    let ty = arg
165                        .get_type(&self.type_state)
166                        .resolve(&self.type_state)
167                        .into_known(&self.type_state)
168                        .ok_or_else(|| {
169                            Diagnostic::error(&arg, "The type of this argument is not fully known")
170                                .primary_label("Type is not fully known")
171                        })?;
172                    Ok((arg, ty))
173                })
174                .collect::<Result<Vec<_>>>()?;
175
176            self.replacements.insert(
177                lambda_unit.clone(),
178                LambdaReplacement {
179                    new_body: body.as_ref().clone(),
180                    arguments: arguments.clone(),
181                    captured_type_params: captured_generic_params
182                        .iter()
183                        .map(|tp| (tp.name_in_lambda.clone(), tp.name_in_body.inner.clone()))
184                        .collect(),
185                },
186            );
187
188            *expression = ExprKind::Call {
189                kind: CallKind::Function,
190                callee: lambda_type.clone().at_loc(expression),
191                args: ArgumentList::Positional(vec![]).at_loc(expression),
192                turbofish: None,
193            }
194            .with_id(expression.id)
195            .at_loc(expression);
196
197            Ok(())
198        } else {
199            Ok(())
200        }
201    }
202
203    fn visit_unit(&mut self, _unit: &mut Unit) -> Result<()> {
204        Ok(())
205    }
206}