1use spade_ast as ast;
2use spade_common::location_info::WithLocation;
3use spade_common::name::Identifier;
4use spade_common::name::Path;
5use spade_diagnostics::diag_anyhow;
6use spade_diagnostics::diag_bail;
7use spade_diagnostics::Diagnostic;
8use spade_hir as hir;
9use spade_hir::expression::CapturedLambdaParam;
10use spade_types::meta_types::MetaType;
11
12use crate::global_symbols::re_visit_type_declaration;
13use crate::global_symbols::visit_type_declaration;
14use crate::impls::visit_impl;
15use crate::visit_block;
16use crate::visit_pattern;
17use crate::Context;
18use crate::LocExt;
19use crate::Result;
20
21pub fn visit_lambda(e: &ast::Expression, ctx: &mut Context) -> Result<hir::ExprKind> {
59 let ast::Expression::Lambda {
60 unit_kind,
61 args,
62 body,
63 } = &e
64 else {
65 panic!("visit_lambda called with non-lambda");
66 };
67
68 let debug_loc = unit_kind.loc();
69 let loc = ().between_locs(unit_kind, body);
70
71 let type_name = Identifier(format!("Lambda"));
72 let output_type_name = Identifier("Output".to_string());
73
74 let current_unit = ctx.current_unit.clone().ok_or_else(|| {
75 diag_anyhow!(loc, "Did not have a current_unit when visiting this lambda")
76 })?;
77
78 let arg_output_generic_param_names = args
79 .iter()
80 .enumerate()
81 .map(|(i, arg)| Identifier(format!("A{i}")).at_loc(arg))
82 .chain(vec![output_type_name.clone().nowhere()])
83 .collect::<Vec<_>>();
84
85 let captured_generic_params = current_unit
86 .unit_type_params
87 .iter()
88 .chain(current_unit.scope_type_params.iter())
89 .cloned()
90 .collect::<Vec<_>>();
91
92 let all_generic_param_names = arg_output_generic_param_names
93 .clone()
94 .into_iter()
95 .chain(
96 captured_generic_params
97 .iter()
98 .map(|p| p.name_id().1.tail().at_loc(&p)),
99 )
100 .collect::<Vec<_>>();
101
102 let type_params = arg_output_generic_param_names
103 .iter()
104 .map(|name| {
105 ast::TypeParam::TypeName {
106 name: name.clone(),
107 traits: vec![],
108 }
109 .at_loc(name)
110 })
111 .chain(
112 captured_generic_params
113 .iter()
114 .map(|tp| {
115 Ok(ast::TypeParam::TypeWithMeta {
116 meta: match tp.meta {
118 MetaType::Int => Identifier("int".to_string()).at_loc(tp),
119 MetaType::Uint => Identifier("uint".to_string()).at_loc(tp),
120 MetaType::Bool => Identifier("bool".to_string()).at_loc(tp),
121 MetaType::Any | MetaType::Type | MetaType::Number => {
122 diag_bail!(loc, "Found unexpected meta in captured type args")
123 }
124 },
125 name: tp.name_id().1.tail().at_loc(tp),
126 }
127 .at_loc(tp))
128 })
129 .collect::<Result<Vec<_>>>()?
130 .into_iter(),
131 )
132 .collect::<Vec<_>>()
133 .at_loc(&debug_loc);
134
135 let args_spec = ast::TypeSpec::Tuple(
136 args.iter()
137 .enumerate()
138 .map(|(i, arg)| {
139 ast::TypeExpression::TypeSpec(Box::new(
140 ast::TypeSpec::Named(
141 Path::ident(Identifier(format!("A{i}")).at_loc(arg)).at_loc(arg),
142 None,
143 )
144 .nowhere(),
145 ))
146 .at_loc(arg)
147 })
148 .collect::<Vec<_>>(),
149 )
150 .nowhere();
151
152 let type_decl = ast::TypeDeclaration {
153 name: type_name.clone().at_loc(&debug_loc),
154 kind: spade_ast::TypeDeclKind::Struct(
155 ast::Struct {
156 attributes: ast::AttributeList::empty(),
157 name: type_name.clone().at_loc(&debug_loc),
158 members: ast::ParameterList::without_self(vec![]).at_loc(&debug_loc),
159 port_keyword: None,
160 }
161 .at_loc(&debug_loc),
162 ),
163 generic_args: Some(type_params.clone()),
164 }
165 .at_loc(&debug_loc);
166
167 ctx.in_fresh_unit(|ctx| visit_type_declaration(&type_decl, ctx))?;
168 ctx.in_fresh_unit(|ctx| re_visit_type_declaration(&type_decl, ctx))?;
169
170 let impl_block = ast::ImplBlock {
171 r#trait: Some(
172 ast::TraitSpec {
173 path: Path::from_strs(&["Fn"]).nowhere(),
174 type_params: Some(
175 vec![
176 ast::TypeExpression::TypeSpec(Box::new(args_spec.clone())).nowhere(),
177 ast::TypeExpression::TypeSpec(Box::new(
178 ast::TypeSpec::Named(
179 Path::ident(output_type_name.clone().nowhere()).nowhere(),
180 None,
181 )
182 .nowhere(),
183 ))
184 .nowhere(),
185 ]
186 .nowhere(),
187 ),
188 }
189 .at_loc(&debug_loc),
190 ),
191 type_params: Some(type_params),
192 where_clauses: vec![],
193 target: ast::TypeSpec::Named(
194 Path::ident(type_name.clone().nowhere()).nowhere(),
195 Some(
196 all_generic_param_names
197 .iter()
198 .map(|name| {
199 ast::TypeExpression::TypeSpec(Box::new(
200 ast::TypeSpec::Named(Path::ident(name.clone()).at_loc(name), None)
201 .at_loc(name),
202 ))
203 .at_loc(name)
204 })
205 .collect::<Vec<_>>()
206 .nowhere(),
207 ),
208 )
209 .nowhere(),
210 units: vec![ast::Unit {
211 head: ast::UnitHead {
212 extern_token: None,
213 attributes: ast::AttributeList(vec![]),
214 unit_kind: unit_kind.clone(),
215 name: Identifier("call".to_string()).nowhere(),
216 inputs: ast::ParameterList {
217 self_: Some(().nowhere()),
218 args: vec![(
219 ast::AttributeList(vec![]),
220 Identifier("args".to_string()).nowhere(),
221 args_spec,
222 )],
223 }
224 .nowhere(),
225 output_type: Some((
226 ().nowhere(),
227 ast::TypeSpec::Named(Path::ident(output_type_name.nowhere()).nowhere(), None)
228 .nowhere(),
229 )),
230 type_params: None,
231 where_clauses: vec![],
232 },
233 body: Some(
234 ast::Expression::Block(Box::new(ast::Block {
235 statements: vec![],
236 result: Some(
237 ast::Expression::StaticUnreachable(
238 "Compiler bug: Lambda body was not lowered during monomorphization"
239 .to_string()
240 .at_loc(body),
241 )
242 .at_loc(body),
243 ),
244 }))
245 .at_loc(body),
246 ),
247 }
248 .at_loc(&debug_loc)],
249 };
250
251 let lambda_unit = ctx.in_fresh_unit(|ctx| {
252 match visit_impl(&impl_block.at_loc(&debug_loc), ctx)?.as_slice() {
253 [item] => {
254 let u = item.assume_unit();
255 ctx.item_list.add_executable(
256 u.name.name_id().clone(),
257 hir::ExecutableItem::Unit(u.clone().at_loc(&loc)),
258 )?;
259 Ok::<_, Diagnostic>(u.clone())
260 }
261 _ => diag_bail!(loc, "Lambda impl block produced more than one item"),
262 }
263 })?;
264
265 let (callee_name, callee_struct) = ctx
266 .symtab
267 .lookup_struct(&Path::ident(type_name.at_loc(&debug_loc)).at_loc(&debug_loc))?;
268
269 ctx.symtab
270 .new_scope_with_barrier(Box::new(|name, previous, thing| match thing {
271 spade_hir::symbol_table::Thing::Variable(_) => {
272 Err(Diagnostic::error(name, "Lambda captures are not supported")
273 .primary_label("This variable is captured")
274 .secondary_label(previous, "The variable is defined outside the lambda here"))
275 }
276 spade_hir::symbol_table::Thing::PipelineStage(_) => Err(Diagnostic::error(
277 name,
278 "Pipeline stages cannot cross lambda functions",
279 )
280 .primary_label("Capturing a pipeline stage...")
281 .secondary_label(previous, "That is defined outside the lambda")),
282 spade_hir::symbol_table::Thing::Struct(_)
283 | spade_hir::symbol_table::Thing::EnumVariant(_)
284 | spade_hir::symbol_table::Thing::Unit(_)
285 | spade_hir::symbol_table::Thing::Alias {
286 path: _,
287 in_namespace: _,
288 }
289 | spade_hir::symbol_table::Thing::Module(_)
290 | spade_hir::symbol_table::Thing::Trait(_) => Ok(()),
291 }));
292 let arguments = args
293 .iter()
294 .map(|arg| arg.try_visit(visit_pattern, ctx))
295 .collect::<Result<Vec<_>>>()?;
296 let body = body.try_map_ref(|body| visit_block(body, ctx));
297 ctx.symtab.close_scope();
298 let body = Box::new(
299 body?.map(|body| hir::ExprKind::Block(Box::new(body)).with_id(ctx.idtracker.next())),
300 );
301
302 Ok(hir::ExprKind::LambdaDef {
303 lambda_type: callee_name,
304 lambda_type_params: callee_struct.type_params.clone(),
305 lambda_unit: lambda_unit.name.name_id().inner.clone(),
306 captured_generic_params: captured_generic_params
307 .iter()
308 .zip(
311 lambda_unit
312 .head
313 .scope_type_params
314 .iter()
315 .skip(arg_output_generic_param_names.len()),
316 )
317 .map(|(in_body, in_lambda)| CapturedLambdaParam {
318 name_in_lambda: in_lambda.name_id(),
319 name_in_body: in_body.name_id().at_loc(in_body),
320 })
321 .collect(),
322 arguments,
323 body,
324 })
325}