spade_hir_lowering/
monomorphisation.rs

1use std::collections::{BTreeMap, HashMap, VecDeque};
2
3use itertools::Itertools;
4use mir::passes::MirPass;
5use spade_common::location_info::Loc;
6use spade_common::{id_tracker::ExprIdTracker, location_info::WithLocation, name::NameID};
7use spade_diagnostics::diagnostic::{Message, Subdiagnostic};
8use spade_diagnostics::{diag_anyhow, DiagHandler, Diagnostic};
9use spade_hir::{symbol_table::FrozenSymtab, ExecutableItem, ItemList, UnitName};
10use spade_mir as mir;
11use spade_typeinference::equation::KnownTypeVar;
12use spade_typeinference::error::UnificationErrorExt;
13use spade_typeinference::trace_stack::{format_trace_stack, TraceStackEntry};
14use spade_typeinference::{GenericListToken, HasType, TypeState};
15
16use crate::error::Result;
17use crate::generate_unit;
18use crate::name_map::NameSourceMap;
19use crate::passes::disallow_inout_bindings::InOutChecks;
20use crate::passes::disallow_zero_size::DisallowZeroSize;
21use crate::passes::flatten_regs::FlattenRegs;
22use crate::passes::lower_lambda_defs::{LambdaReplacement, LowerLambdaDefs};
23use crate::passes::lower_methods::LowerMethods;
24use crate::passes::lower_type_level_if::LowerTypeLevelIf;
25use crate::passes::pass::{Pass, Passable};
26
27/// An item to be monomorphised
28#[derive(Clone, Hash, PartialEq, Eq)]
29pub struct MonoItem {
30    /// The name of the original item which this is a monomorphised version of
31    pub source_name: Loc<NameID>,
32    /// The new name of the new item
33    pub new_name: UnitName,
34    /// The types to replace the generic types in the item. Positional replacement.
35    /// These are TypeVars which have to be fully known
36    pub params: Vec<KnownTypeVar>,
37}
38
39pub struct MonoState {
40    /// List of mono items left to compile
41    to_compile: VecDeque<MonoItem>,
42    /// Mapping between items with types specified and names
43    translation: BTreeMap<(NameID, Vec<KnownTypeVar>), NameID>,
44    /// Locations in the code where compilation of the Mono item was requested. None
45    /// if this is non-generic
46    request_points: HashMap<MonoItem, Option<(MonoItem, Loc<()>)>>,
47}
48
49impl Default for MonoState {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl MonoState {
56    pub fn new() -> MonoState {
57        MonoState {
58            to_compile: VecDeque::new(),
59            translation: BTreeMap::new(),
60            request_points: HashMap::new(),
61        }
62    }
63
64    /// Request compilation of a unit with the specified type parameters, returning the name of the
65    /// unit which will be compiled with these parameters. It is up to the caller of this
66    /// function to ensure that the type params are valid for this item.
67    pub fn request_compilation(
68        &mut self,
69        source_name: UnitName,
70        reuse_nameid: bool,
71        params: Vec<KnownTypeVar>,
72        symtab: &mut FrozenSymtab,
73        request_point: Option<(MonoItem, Loc<()>)>,
74    ) -> NameID {
75        match self
76            .translation
77            .get(&(source_name.name_id().inner.clone(), params.clone()))
78        {
79            Some(prev) => prev.clone(),
80            None => {
81                let new_name = if reuse_nameid {
82                    source_name.name_id().inner.clone()
83                } else {
84                    symtab.new_name(source_name.name_id().1.clone())
85                };
86
87                // Wrap the new name in a UnitName to match the source. Previous steps
88                // ensure that the unit name is general enough to not cause name collisions
89                let new_unit_name = match &source_name {
90                    UnitName::WithID(_) => UnitName::WithID(new_name.clone().nowhere()),
91                    UnitName::FullPath(_) => UnitName::FullPath(new_name.clone().nowhere()),
92                    UnitName::Unmangled(source, _) => {
93                        UnitName::Unmangled(source.clone(), new_name.clone().nowhere())
94                    }
95                };
96
97                let item = MonoItem {
98                    source_name: source_name.name_id().clone(),
99                    new_name: new_unit_name,
100                    params: params.clone(),
101                };
102                self.request_points.insert(item.clone(), request_point);
103
104                self.translation.insert(
105                    (source_name.name_id().inner.clone(), params.clone()),
106                    new_name.clone(),
107                );
108                self.to_compile.push_back(item);
109                new_name
110            }
111        }
112    }
113
114    fn next_target(&mut self) -> Option<MonoItem> {
115        self.to_compile.pop_front()
116    }
117
118    fn add_mono_traceback(&self, diagnostic: Diagnostic, item: &MonoItem) -> Diagnostic {
119        let parent = self.request_points.get(item).and_then(|x| x.clone());
120        if let Some((next_parent, loc)) = parent {
121            let generic_params = item.params.iter().map(|p| format!("{p}")).join(", ");
122
123            let new = diagnostic.subdiagnostic(Subdiagnostic::TemplateTraceback {
124                span: loc.into(),
125                message: Message::from(format!("{}<{}>", item.source_name, generic_params)),
126            });
127
128            self.add_mono_traceback(new, &next_parent)
129        } else {
130            diagnostic
131        }
132    }
133}
134
135pub struct MirOutput {
136    pub mir: mir::Entity,
137    pub type_state: TypeState,
138    /// Mapping between new names for registers and their previous value. Used
139    /// to add type information for registers generated by pipelines
140    pub reg_name_map: BTreeMap<NameID, NameID>,
141}
142
143pub fn compile_items(
144    items: &BTreeMap<&NameID, (&ExecutableItem, TypeState)>,
145    symtab: &mut FrozenSymtab,
146    idtracker: &mut ExprIdTracker,
147    name_source_map: &mut NameSourceMap,
148    item_list: &ItemList,
149    diag_handler: &mut DiagHandler,
150    opt_passes: &[&dyn MirPass],
151    impl_type_state: &TypeState,
152) -> Vec<Result<MirOutput>> {
153    // Build a map of items to use for compilation later. Also push all non
154    // generic items to the compilation queue
155
156    let mut state = MonoState::new();
157
158    for (item, _) in items.values() {
159        match item {
160            ExecutableItem::Unit(u) => {
161                if u.head.get_type_params().is_empty() {
162                    state.request_compilation(u.name.clone(), true, vec![], symtab, None);
163                }
164            }
165            ExecutableItem::StructInstance => {}
166            ExecutableItem::EnumInstance { .. } => {}
167            ExecutableItem::ExternUnit(_, _) => {}
168        }
169    }
170
171    let mut body_replacements: HashMap<NameID, LambdaReplacement> = HashMap::new();
172
173    let mut result = vec![];
174    'item_loop: while let Some(item) = state.next_target() {
175        let original_item = items.get(&item.source_name.inner);
176
177        let mut reg_name_map = BTreeMap::new();
178        match original_item {
179            Some((ExecutableItem::Unit(u), old_type_state)) => {
180                let (u, old_type_state) =
181                    if let Some(replacement) = body_replacements.get(&u.name.name_id().inner) {
182                        let new_unit = match replacement.replace_in(u.clone(), idtracker) {
183                            Ok(u) => u,
184                            Err(e) => {
185                                result.push(Err(state.add_mono_traceback(e, &item)));
186                                break 'item_loop;
187                            }
188                        };
189
190                        (&new_unit.clone(), &{
191                            let mut type_state = impl_type_state.create_child();
192                            let type_ctx = &spade_typeinference::Context {
193                                symtab: symtab.symtab(),
194                                items: item_list,
195                                trait_impls: &old_type_state.trait_impls,
196                            };
197                            let unification_result = type_state.visit_unit_with_preprocessing(
198                                &new_unit,
199                                |type_state, unit, generic_list, ctx| {
200                                    let gl = type_state
201                                        .get_generic_list(generic_list)
202                                        .ok_or_else(|| {
203                                            diag_anyhow!(unit, "Did not have a generic list")
204                                        })?
205                                        .clone();
206                                    for (i, (_, ty)) in replacement.arguments.iter().enumerate() {
207                                        let old_ty = gl
208                                            .get(&unit.head.get_type_params()[i].name_id)
209                                            .ok_or_else(|| {
210                                                diag_anyhow!(
211                                                    unit,
212                                                    "Did not have an entry for argument {i}"
213                                                )
214                                            })?
215                                            .clone();
216
217                                        ty.insert(type_state)
218                                            .unify_with(&old_ty, type_state)
219                                            .commit(type_state, ctx)
220                                            .into_default_diagnostic(unit, type_state)?;
221                                    }
222                                    Ok(())
223                                },
224                                type_ctx,
225                            );
226                            if let Err(e) = unification_result {
227                                result.push(Err(state.add_mono_traceback(e, &item)));
228                                break 'item_loop;
229                            }
230
231                            type_state
232                        })
233                    } else {
234                        (u, old_type_state)
235                    };
236
237                let type_ctx = &spade_typeinference::Context {
238                    symtab: symtab.symtab(),
239                    items: item_list,
240                    trait_impls: &old_type_state.trait_impls,
241                };
242                let mut type_state = old_type_state.create_child();
243                let generic_list_token = if !u.head.get_type_params().is_empty() {
244                    Some(GenericListToken::Definition(u.name.name_id().inner.clone()))
245                } else {
246                    None
247                };
248
249                if let Some(generic_list_token) = &generic_list_token {
250                    let generic_list = type_state
251                        .get_generic_list(generic_list_token)
252                        .expect("Found no generic list  when monomorphizing")
253                        .clone();
254                    for (source_param, new) in
255                        u.head.get_type_params().iter().zip(item.params.iter())
256                    {
257                        let source_var = &generic_list[&source_param.name_id()];
258
259                        type_state
260                            .trace_stack
261                            .push(TraceStackEntry::Message(format!(
262                                "Performing mono replacement of {source_var} -> {new:?}",
263                                source_var = source_var.debug_resolve(&type_state),
264                            )));
265
266                        let tvar = new.insert(&mut type_state);
267                        match type_state
268                            .unify(&tvar, source_var, type_ctx)
269                            .into_default_diagnostic(u, &type_state)
270                            .and_then(|_| type_state.check_requirements(true, type_ctx))
271                        {
272                            Ok(_) => {}
273                            Err(e) => {
274                                result.push(Err(state.add_mono_traceback(e, &item)));
275                                continue 'item_loop;
276                            }
277                        }
278                    }
279
280                    if std::env::var("SPADE_TRACE_TYPEINFERENCE").is_ok() {
281                        println!(
282                            "After mono of {} replacing {} with {}",
283                            u.inner.name,
284                            u.head
285                                .get_type_params()
286                                .iter()
287                                .map(|p| format!("{p:?}"))
288                                .join(", "),
289                            item.params.iter().map(|p| format!("{p}")).join(", ")
290                        );
291                        type_state.print_equations();
292                        println!("{}", format_trace_stack(&type_state));
293                    }
294                }
295
296                // Apply passes to the type checked module
297                let mut u = u.clone();
298                let passes = [
299                    &mut LowerLambdaDefs {
300                        type_state: &type_state,
301                        replacements: &mut body_replacements,
302                    } as &mut dyn Pass,
303                    &mut FlattenRegs {
304                        type_state: &type_state,
305                        items: item_list,
306                        symtab,
307                    } as &mut dyn Pass,
308                    &mut LowerMethods {
309                        type_state: &type_state,
310                        items: item_list,
311                        symtab,
312                    } as &mut dyn Pass,
313                    &mut LowerTypeLevelIf {
314                        type_state: &type_state,
315                        items: item_list,
316                        symtab,
317                        allowed_ids: Default::default(),
318                    } as &mut dyn Pass,
319                    &mut InOutChecks {
320                        type_state: &type_state,
321                        items: item_list,
322                        symtab,
323                    } as &mut dyn Pass,
324                    &mut DisallowZeroSize {
325                        type_state: &type_state,
326                        items: item_list,
327                        symtab,
328                    } as &mut dyn Pass,
329                ];
330                for pass in passes {
331                    let pass_result = u.apply(pass);
332                    if let Err(e) = pass_result {
333                        result.push(Err(e));
334                        continue 'item_loop;
335                    }
336                }
337
338                let self_mono_item = Some(item.clone());
339                let out = generate_unit(
340                    &u.inner,
341                    item.new_name.clone(),
342                    &mut type_state,
343                    symtab,
344                    idtracker,
345                    item_list,
346                    &generic_list_token,
347                    &mut reg_name_map,
348                    &mut state,
349                    diag_handler,
350                    name_source_map,
351                    self_mono_item,
352                    opt_passes,
353                )
354                .map_err(|e| state.add_mono_traceback(e, &item))
355                .map(|mir| MirOutput {
356                    mir,
357                    type_state: type_state.clone(),
358                    reg_name_map,
359                });
360                result.push(out);
361            }
362            Some((ExecutableItem::StructInstance, _)) => {
363                panic!("Requesting compilation of struct instance as module")
364            }
365            Some((ExecutableItem::EnumInstance { .. }, _)) => {
366                panic!("Requesting compilation of enum instance as module")
367            }
368            Some((ExecutableItem::ExternUnit(_, _), _)) => {
369                panic!("Requesting compilation of extern unit")
370            }
371            None => {}
372        }
373    }
374    result
375}