1use std::collections::{BTreeMap, HashMap, VecDeque};
2
3use itertools::Itertools;
4use mir::passes::MirPass;
5use spade_common::location_info::Loc;
6use spade_common::{id_tracker::ExprIdTracker, location_info::WithLocation, name::NameID};
7use spade_diagnostics::diagnostic::{Message, Subdiagnostic};
8use spade_diagnostics::{diag_anyhow, DiagHandler, Diagnostic};
9use spade_hir::{symbol_table::FrozenSymtab, ExecutableItem, ItemList, UnitName};
10use spade_mir as mir;
11use spade_typeinference::equation::KnownTypeVar;
12use spade_typeinference::error::UnificationErrorExt;
13use spade_typeinference::trace_stack::{format_trace_stack, TraceStackEntry};
14use spade_typeinference::{GenericListToken, HasType, TypeState};
15
16use crate::error::Result;
17use crate::generate_unit;
18use crate::name_map::NameSourceMap;
19use crate::passes::disallow_inout_bindings::InOutChecks;
20use crate::passes::disallow_zero_size::DisallowZeroSize;
21use crate::passes::flatten_regs::FlattenRegs;
22use crate::passes::lower_lambda_defs::{LambdaReplacement, LowerLambdaDefs};
23use crate::passes::lower_methods::LowerMethods;
24use crate::passes::lower_type_level_if::LowerTypeLevelIf;
25use crate::passes::pass::{Pass, Passable};
26
27#[derive(Clone, Hash, PartialEq, Eq)]
29pub struct MonoItem {
30 pub source_name: Loc<NameID>,
32 pub new_name: UnitName,
34 pub params: Vec<KnownTypeVar>,
37}
38
39pub struct MonoState {
40 to_compile: VecDeque<MonoItem>,
42 translation: BTreeMap<(NameID, Vec<KnownTypeVar>), NameID>,
44 request_points: HashMap<MonoItem, Option<(MonoItem, Loc<()>)>>,
47}
48
49impl Default for MonoState {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl MonoState {
56 pub fn new() -> MonoState {
57 MonoState {
58 to_compile: VecDeque::new(),
59 translation: BTreeMap::new(),
60 request_points: HashMap::new(),
61 }
62 }
63
64 pub fn request_compilation(
68 &mut self,
69 source_name: UnitName,
70 reuse_nameid: bool,
71 params: Vec<KnownTypeVar>,
72 symtab: &mut FrozenSymtab,
73 request_point: Option<(MonoItem, Loc<()>)>,
74 ) -> NameID {
75 match self
76 .translation
77 .get(&(source_name.name_id().inner.clone(), params.clone()))
78 {
79 Some(prev) => prev.clone(),
80 None => {
81 let new_name = if reuse_nameid {
82 source_name.name_id().inner.clone()
83 } else {
84 symtab.new_name(source_name.name_id().1.clone())
85 };
86
87 let new_unit_name = match &source_name {
90 UnitName::WithID(_) => UnitName::WithID(new_name.clone().nowhere()),
91 UnitName::FullPath(_) => UnitName::FullPath(new_name.clone().nowhere()),
92 UnitName::Unmangled(source, _) => {
93 UnitName::Unmangled(source.clone(), new_name.clone().nowhere())
94 }
95 };
96
97 let item = MonoItem {
98 source_name: source_name.name_id().clone(),
99 new_name: new_unit_name,
100 params: params.clone(),
101 };
102 self.request_points.insert(item.clone(), request_point);
103
104 self.translation.insert(
105 (source_name.name_id().inner.clone(), params.clone()),
106 new_name.clone(),
107 );
108 self.to_compile.push_back(item);
109 new_name
110 }
111 }
112 }
113
114 fn next_target(&mut self) -> Option<MonoItem> {
115 self.to_compile.pop_front()
116 }
117
118 fn add_mono_traceback(&self, diagnostic: Diagnostic, item: &MonoItem) -> Diagnostic {
119 let parent = self.request_points.get(item).and_then(|x| x.clone());
120 if let Some((next_parent, loc)) = parent {
121 let generic_params = item.params.iter().map(|p| format!("{p}")).join(", ");
122
123 let new = diagnostic.subdiagnostic(Subdiagnostic::TemplateTraceback {
124 span: loc.into(),
125 message: Message::from(format!("{}<{}>", item.source_name, generic_params)),
126 });
127
128 self.add_mono_traceback(new, &next_parent)
129 } else {
130 diagnostic
131 }
132 }
133}
134
135pub struct MirOutput {
136 pub mir: mir::Entity,
137 pub type_state: TypeState,
138 pub reg_name_map: BTreeMap<NameID, NameID>,
141}
142
143pub fn compile_items(
144 items: &BTreeMap<&NameID, (&ExecutableItem, TypeState)>,
145 symtab: &mut FrozenSymtab,
146 idtracker: &mut ExprIdTracker,
147 name_source_map: &mut NameSourceMap,
148 item_list: &ItemList,
149 diag_handler: &mut DiagHandler,
150 opt_passes: &[&dyn MirPass],
151 impl_type_state: &TypeState,
152) -> Vec<Result<MirOutput>> {
153 let mut state = MonoState::new();
157
158 for (item, _) in items.values() {
159 match item {
160 ExecutableItem::Unit(u) => {
161 if u.head.get_type_params().is_empty() {
162 state.request_compilation(u.name.clone(), true, vec![], symtab, None);
163 }
164 }
165 ExecutableItem::StructInstance => {}
166 ExecutableItem::EnumInstance { .. } => {}
167 ExecutableItem::ExternUnit(_, _) => {}
168 }
169 }
170
171 let mut body_replacements: HashMap<NameID, LambdaReplacement> = HashMap::new();
172
173 let mut result = vec![];
174 'item_loop: while let Some(item) = state.next_target() {
175 let original_item = items.get(&item.source_name.inner);
176
177 let mut reg_name_map = BTreeMap::new();
178 match original_item {
179 Some((ExecutableItem::Unit(u), old_type_state)) => {
180 let (u, old_type_state) =
181 if let Some(replacement) = body_replacements.get(&u.name.name_id().inner) {
182 let new_unit = match replacement.replace_in(u.clone(), idtracker) {
183 Ok(u) => u,
184 Err(e) => {
185 result.push(Err(state.add_mono_traceback(e, &item)));
186 break 'item_loop;
187 }
188 };
189
190 (&new_unit.clone(), &{
191 let mut type_state = impl_type_state.create_child();
192 let type_ctx = &spade_typeinference::Context {
193 symtab: symtab.symtab(),
194 items: item_list,
195 trait_impls: &old_type_state.trait_impls,
196 };
197 let unification_result = type_state.visit_unit_with_preprocessing(
198 &new_unit,
199 |type_state, unit, generic_list, ctx| {
200 let gl = type_state
201 .get_generic_list(generic_list)
202 .ok_or_else(|| {
203 diag_anyhow!(unit, "Did not have a generic list")
204 })?
205 .clone();
206 for (i, (_, ty)) in replacement.arguments.iter().enumerate() {
207 let old_ty = gl
208 .get(&unit.head.get_type_params()[i].name_id)
209 .ok_or_else(|| {
210 diag_anyhow!(
211 unit,
212 "Did not have an entry for argument {i}"
213 )
214 })?
215 .clone();
216
217 ty.insert(type_state)
218 .unify_with(&old_ty, type_state)
219 .commit(type_state, ctx)
220 .into_default_diagnostic(unit, type_state)?;
221 }
222 Ok(())
223 },
224 type_ctx,
225 );
226 if let Err(e) = unification_result {
227 result.push(Err(state.add_mono_traceback(e, &item)));
228 break 'item_loop;
229 }
230
231 type_state
232 })
233 } else {
234 (u, old_type_state)
235 };
236
237 let type_ctx = &spade_typeinference::Context {
238 symtab: symtab.symtab(),
239 items: item_list,
240 trait_impls: &old_type_state.trait_impls,
241 };
242 let mut type_state = old_type_state.create_child();
243 let generic_list_token = if !u.head.get_type_params().is_empty() {
244 Some(GenericListToken::Definition(u.name.name_id().inner.clone()))
245 } else {
246 None
247 };
248
249 if let Some(generic_list_token) = &generic_list_token {
250 let generic_list = type_state
251 .get_generic_list(generic_list_token)
252 .expect("Found no generic list when monomorphizing")
253 .clone();
254 for (source_param, new) in
255 u.head.get_type_params().iter().zip(item.params.iter())
256 {
257 let source_var = &generic_list[&source_param.name_id()];
258
259 type_state
260 .trace_stack
261 .push(TraceStackEntry::Message(format!(
262 "Performing mono replacement of {source_var} -> {new:?}",
263 source_var = source_var.debug_resolve(&type_state),
264 )));
265
266 let tvar = new.insert(&mut type_state);
267 match type_state
268 .unify(&tvar, source_var, type_ctx)
269 .into_default_diagnostic(u, &type_state)
270 .and_then(|_| type_state.check_requirements(true, type_ctx))
271 {
272 Ok(_) => {}
273 Err(e) => {
274 result.push(Err(state.add_mono_traceback(e, &item)));
275 continue 'item_loop;
276 }
277 }
278 }
279
280 if std::env::var("SPADE_TRACE_TYPEINFERENCE").is_ok() {
281 println!(
282 "After mono of {} replacing {} with {}",
283 u.inner.name,
284 u.head
285 .get_type_params()
286 .iter()
287 .map(|p| format!("{p:?}"))
288 .join(", "),
289 item.params.iter().map(|p| format!("{p}")).join(", ")
290 );
291 type_state.print_equations();
292 println!("{}", format_trace_stack(&type_state));
293 }
294 }
295
296 let mut u = u.clone();
298 let passes = [
299 &mut LowerLambdaDefs {
300 type_state: &type_state,
301 replacements: &mut body_replacements,
302 } as &mut dyn Pass,
303 &mut FlattenRegs {
304 type_state: &type_state,
305 items: item_list,
306 symtab,
307 } as &mut dyn Pass,
308 &mut LowerMethods {
309 type_state: &type_state,
310 items: item_list,
311 symtab,
312 } as &mut dyn Pass,
313 &mut LowerTypeLevelIf {
314 type_state: &type_state,
315 items: item_list,
316 symtab,
317 allowed_ids: Default::default(),
318 } as &mut dyn Pass,
319 &mut InOutChecks {
320 type_state: &type_state,
321 items: item_list,
322 symtab,
323 } as &mut dyn Pass,
324 &mut DisallowZeroSize {
325 type_state: &type_state,
326 items: item_list,
327 symtab,
328 } as &mut dyn Pass,
329 ];
330 for pass in passes {
331 let pass_result = u.apply(pass);
332 if let Err(e) = pass_result {
333 result.push(Err(e));
334 continue 'item_loop;
335 }
336 }
337
338 let self_mono_item = Some(item.clone());
339 let out = generate_unit(
340 &u.inner,
341 item.new_name.clone(),
342 &mut type_state,
343 symtab,
344 idtracker,
345 item_list,
346 &generic_list_token,
347 &mut reg_name_map,
348 &mut state,
349 diag_handler,
350 name_source_map,
351 self_mono_item,
352 opt_passes,
353 )
354 .map_err(|e| state.add_mono_traceback(e, &item))
355 .map(|mir| MirOutput {
356 mir,
357 type_state: type_state.clone(),
358 reg_name_map,
359 });
360 result.push(out);
361 }
362 Some((ExecutableItem::StructInstance, _)) => {
363 panic!("Requesting compilation of struct instance as module")
364 }
365 Some((ExecutableItem::EnumInstance { .. }, _)) => {
366 panic!("Requesting compilation of enum instance as module")
367 }
368 Some((ExecutableItem::ExternUnit(_, _), _)) => {
369 panic!("Requesting compilation of extern unit")
370 }
371 None => {}
372 }
373 }
374 result
375}