spade_ast_lowering/
lambda.rs

1use spade_ast as ast;
2use spade_common::location_info::WithLocation;
3use spade_common::name::Identifier;
4use spade_common::name::Path;
5use spade_diagnostics::diag_anyhow;
6use spade_diagnostics::diag_bail;
7use spade_diagnostics::Diagnostic;
8use spade_hir as hir;
9use spade_hir::expression::CapturedLambdaParam;
10use spade_types::meta_types::MetaType;
11
12use crate::global_symbols::re_visit_type_declaration;
13use crate::global_symbols::visit_type_declaration;
14use crate::impls::visit_impl;
15use crate::visit_block;
16use crate::visit_pattern;
17use crate::Context;
18use crate::LocExt;
19use crate::Result;
20
21/*
22
23ast lowering:
24- Add a type for the lambda
    - Mono needs this type to be generic over every generic parameter in scope where the lambda is defined:
26        - Function generics
27        - Impl block generics
28    - Add an impl block for `Fn<...>`
29        impl Fn<(Args), Output> for LambdaT {
30            fn call(self, args: Args) -> Output {
31                ... placeholder
32            }
33        }
34
35- Typechecking
36    - Typecheck the body as if it were inline
37
38- Post mono
39    - Replace lambda body placeholders
40
41```spade
fn (a, b, c) {/* body */} => {LambdaDef<A, B, C, O>(), /* body */}
43
44// These are added
45struct Lambda<A, B, C, O> {}
46
47impl<A, B, C, O> Fn<(A, B, C), O> for Lambda<A, B, C, O> {
48    fn call(self, args: (A, B, C)) -> O {
49        // placeholder
50    }
51}
52
// After typechecking and monomorphization, the placeholder body is replaced with the actual lambda body
54
55```
56*/
57
58pub fn visit_lambda(e: &ast::Expression, ctx: &mut Context) -> Result<hir::ExprKind> {
59    let ast::Expression::Lambda {
60        unit_kind,
61        args,
62        body,
63    } = &e
64    else {
65        panic!("visit_lambda called with non-lambda");
66    };
67
68    let debug_loc = unit_kind.loc();
69    let loc = ().between_locs(unit_kind, body);
70
71    let type_name = Identifier(format!("Lambda"));
72    let output_type_name = Identifier("Output".to_string());
73
74    let current_unit = ctx.current_unit.clone().ok_or_else(|| {
75        diag_anyhow!(loc, "Did not have a current_unit when visiting this lambda")
76    })?;
77
78    let arg_output_generic_param_names = args
79        .iter()
80        .enumerate()
81        .map(|(i, arg)| Identifier(format!("A{i}")).at_loc(arg))
82        .chain(vec![output_type_name.clone().nowhere()])
83        .collect::<Vec<_>>();
84
85    let captured_generic_params = current_unit
86        .unit_type_params
87        .iter()
88        .chain(current_unit.scope_type_params.iter())
89        .cloned()
90        .collect::<Vec<_>>();
91
92    let all_generic_param_names = arg_output_generic_param_names
93        .clone()
94        .into_iter()
95        .chain(
96            captured_generic_params
97                .iter()
98                .map(|p| p.name_id().1.tail().at_loc(&p)),
99        )
100        .collect::<Vec<_>>();
101
102    let type_params = arg_output_generic_param_names
103        .iter()
104        .map(|name| {
105            ast::TypeParam::TypeName {
106                name: name.clone(),
107                traits: vec![],
108            }
109            .at_loc(name)
110        })
111        .chain(
112            captured_generic_params
113                .iter()
114                .map(|tp| {
115                    Ok(ast::TypeParam::TypeWithMeta {
116                        // NOTE: Recreating the meta-type like this is kind of strange, but works for now. If we add more meta-types in the future, recondsider this decision
117                        meta: match tp.meta {
118                            MetaType::Int => Identifier("int".to_string()).at_loc(tp),
119                            MetaType::Uint => Identifier("uint".to_string()).at_loc(tp),
120                            MetaType::Bool => Identifier("bool".to_string()).at_loc(tp),
121                            MetaType::Any | MetaType::Type | MetaType::Number => {
122                                diag_bail!(loc, "Found unexpected meta in captured type args")
123                            }
124                        },
125                        name: tp.name_id().1.tail().at_loc(tp),
126                    }
127                    .at_loc(tp))
128                })
129                .collect::<Result<Vec<_>>>()?
130                .into_iter(),
131        )
132        .collect::<Vec<_>>()
133        .at_loc(&debug_loc);
134
135    let args_spec = ast::TypeSpec::Tuple(
136        args.iter()
137            .enumerate()
138            .map(|(i, arg)| {
139                ast::TypeExpression::TypeSpec(Box::new(
140                    ast::TypeSpec::Named(
141                        Path::ident(Identifier(format!("A{i}")).at_loc(arg)).at_loc(arg),
142                        None,
143                    )
144                    .nowhere(),
145                ))
146                .at_loc(arg)
147            })
148            .collect::<Vec<_>>(),
149    )
150    .nowhere();
151
152    let type_decl = ast::TypeDeclaration {
153        name: type_name.clone().at_loc(&debug_loc),
154        kind: spade_ast::TypeDeclKind::Struct(
155            ast::Struct {
156                attributes: ast::AttributeList::empty(),
157                name: type_name.clone().at_loc(&debug_loc),
158                members: ast::ParameterList::without_self(vec![]).at_loc(&debug_loc),
159                port_keyword: None,
160            }
161            .at_loc(&debug_loc),
162        ),
163        generic_args: Some(type_params.clone()),
164    }
165    .at_loc(&debug_loc);
166
167    ctx.in_fresh_unit(|ctx| visit_type_declaration(&type_decl, ctx))?;
168    ctx.in_fresh_unit(|ctx| re_visit_type_declaration(&type_decl, ctx))?;
169
170    let impl_block = ast::ImplBlock {
171        r#trait: Some(
172            ast::TraitSpec {
173                path: Path::from_strs(&["Fn"]).nowhere(),
174                type_params: Some(
175                    vec![
176                        ast::TypeExpression::TypeSpec(Box::new(args_spec.clone())).nowhere(),
177                        ast::TypeExpression::TypeSpec(Box::new(
178                            ast::TypeSpec::Named(
179                                Path::ident(output_type_name.clone().nowhere()).nowhere(),
180                                None,
181                            )
182                            .nowhere(),
183                        ))
184                        .nowhere(),
185                    ]
186                    .nowhere(),
187                ),
188            }
189            .at_loc(&debug_loc),
190        ),
191        type_params: Some(type_params),
192        where_clauses: vec![],
193        target: ast::TypeSpec::Named(
194            Path::ident(type_name.clone().nowhere()).nowhere(),
195            Some(
196                all_generic_param_names
197                    .iter()
198                    .map(|name| {
199                        ast::TypeExpression::TypeSpec(Box::new(
200                            ast::TypeSpec::Named(Path::ident(name.clone()).at_loc(name), None)
201                                .at_loc(name),
202                        ))
203                        .at_loc(name)
204                    })
205                    .collect::<Vec<_>>()
206                    .nowhere(),
207            ),
208        )
209        .nowhere(),
210        units: vec![ast::Unit {
211            head: ast::UnitHead {
212                extern_token: None,
213                attributes: ast::AttributeList(vec![]),
214                unit_kind: unit_kind.clone(),
215                name: Identifier("call".to_string()).nowhere(),
216                inputs: ast::ParameterList {
217                    self_: Some(().nowhere()),
218                    args: vec![(
219                        ast::AttributeList(vec![]),
220                        Identifier("args".to_string()).nowhere(),
221                        args_spec,
222                    )],
223                }
224                .nowhere(),
225                output_type: Some((
226                    ().nowhere(),
227                    ast::TypeSpec::Named(Path::ident(output_type_name.nowhere()).nowhere(), None)
228                        .nowhere(),
229                )),
230                type_params: None,
231                where_clauses: vec![],
232            },
233            body: Some(
234                ast::Expression::Block(Box::new(ast::Block {
235                    statements: vec![],
236                    result: Some(
237                        ast::Expression::StaticUnreachable(
238                            "Compiler bug: Lambda body was not lowered during monomorphization"
239                                .to_string()
240                                .at_loc(body),
241                        )
242                        .at_loc(body),
243                    ),
244                }))
245                .at_loc(body),
246            ),
247        }
248        .at_loc(&debug_loc)],
249    };
250
251    let lambda_unit = ctx.in_fresh_unit(|ctx| {
252        match visit_impl(&impl_block.at_loc(&debug_loc), ctx)?.as_slice() {
253            [item] => {
254                let u = item.assume_unit();
255                ctx.item_list.add_executable(
256                    u.name.name_id().clone(),
257                    hir::ExecutableItem::Unit(u.clone().at_loc(&loc)),
258                )?;
259                Ok::<_, Diagnostic>(u.clone())
260            }
261            _ => diag_bail!(loc, "Lambda impl block produced more than one item"),
262        }
263    })?;
264
265    let (callee_name, callee_struct) = ctx
266        .symtab
267        .lookup_struct(&Path::ident(type_name.at_loc(&debug_loc)).at_loc(&debug_loc))?;
268
269    ctx.symtab
270        .new_scope_with_barrier(Box::new(|name, previous, thing| match thing {
271            spade_hir::symbol_table::Thing::Variable(_) => {
272                Err(Diagnostic::error(name, "Lambda captures are not supported")
273                    .primary_label("This variable is captured")
274                    .secondary_label(previous, "The variable is defined outside the lambda here"))
275            }
276            spade_hir::symbol_table::Thing::PipelineStage(_) => Err(Diagnostic::error(
277                name,
278                "Pipeline stages cannot cross lambda functions",
279            )
280            .primary_label("Capturing a pipeline stage...")
281            .secondary_label(previous, "That is defined outside the lambda")),
282            spade_hir::symbol_table::Thing::Struct(_)
283            | spade_hir::symbol_table::Thing::EnumVariant(_)
284            | spade_hir::symbol_table::Thing::Unit(_)
285            | spade_hir::symbol_table::Thing::Alias {
286                path: _,
287                in_namespace: _,
288            }
289            | spade_hir::symbol_table::Thing::Module(_)
290            | spade_hir::symbol_table::Thing::Trait(_) => Ok(()),
291        }));
292    let arguments = args
293        .iter()
294        .map(|arg| arg.try_visit(visit_pattern, ctx))
295        .collect::<Result<Vec<_>>>()?;
296    let body = body.try_map_ref(|body| visit_block(body, ctx));
297    ctx.symtab.close_scope();
298    let body = Box::new(
299        body?.map(|body| hir::ExprKind::Block(Box::new(body)).with_id(ctx.idtracker.next())),
300    );
301
302    Ok(hir::ExprKind::LambdaDef {
303        lambda_type: callee_name,
304        lambda_type_params: callee_struct.type_params.clone(),
305        lambda_unit: lambda_unit.name.name_id().inner.clone(),
306        captured_generic_params: captured_generic_params
307            .iter()
308            // Kind of cursed, but we need to figure out what name IDs we gave to the captured
309            // arguments while visiting the unit so we can replace them later
310            .zip(
311                lambda_unit
312                    .head
313                    .scope_type_params
314                    .iter()
315                    .skip(arg_output_generic_param_names.len()),
316            )
317            .map(|(in_body, in_lambda)| CapturedLambdaParam {
318                name_in_lambda: in_lambda.name_id(),
319                name_in_body: in_body.name_id().at_loc(in_body),
320            })
321            .collect(),
322        arguments,
323        body,
324    })
325}