spade_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
3use quote::{quote, ToTokens};
4use syn::parse::{Nothing, Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::{parse_macro_input, parse_quote};
7use syn::{Error, Expr, Fields, FieldsNamed, Ident, ImplItemFn, ItemStruct, Token};
8
9// Thanks to discord user Yandros(MemeOverloard) for doing the bulk of the work with this
10// macro
11#[proc_macro_attribute]
12pub fn trace_parser(attrs: TokenStream, input: TokenStream) -> TokenStream {
13    parse_macro_input!(attrs as Nothing);
14    let mut input = parse_macro_input!(input as ImplItemFn);
15    let block = &mut input.block;
16
17    let function_name = format!("{}", input.sig.ident);
18
19    *block = parse_quote!({
20        self.parse_stack.push(ParseStackEntry::Enter(#function_name.to_string()));
21        let ret: Result<_> = (|| #block)();
22        if let Err(e) = &ret {
23            self.parse_stack.push(ParseStackEntry::ExitWithDiagnostic(e.clone()));
24        }
25        else {
26            self.parse_stack.push(ParseStackEntry::Exit);
27        }
28        ret
29    });
30    input.into_token_stream().into()
31}
32
33#[proc_macro_attribute]
34pub fn trace_typechecker(attrs: TokenStream, input: TokenStream) -> TokenStream {
35    parse_macro_input!(attrs as Nothing);
36    let mut input = parse_macro_input!(input as ImplItemFn);
37    let block = &mut input.block;
38
39    let function_name = format!("{}", input.sig.ident);
40
41    *block = parse_quote!({
42        self.trace_stack.push(TraceStackEntry::Enter(#function_name.to_string()));
43        let ret: Result<_> = (|| #block)();
44        self.trace_stack.push(TraceStackEntry::Exit);
45        ret
46    });
47    input.into_token_stream().into()
48}
49
50enum DiagnosticMessage {
51    /// `"message"`
52    Literal(Literal),
53    /// `"literal containing {} or more expressions as format arguments", 1`
54    Formatted(Literal, Punctuated<Expr, Token![,]>),
55}
56
57impl Parse for DiagnosticMessage {
58    fn parse(input: ParseStream) -> syn::Result<Self> {
59        let first = input.parse()?;
60        let _: Option<Token![,]> = input.parse()?;
61        if input.is_empty() {
62            return Ok(DiagnosticMessage::Literal(first));
63        }
64        let rest = Punctuated::parse_terminated(input)?;
65        Ok(DiagnosticMessage::Formatted(first, rest))
66    }
67}
68
69impl DiagnosticMessage {
70    fn quote(&self) -> TokenStream2 {
71        match self {
72            DiagnosticMessage::Literal(lit) => quote!( #lit ),
73            DiagnosticMessage::Formatted(lit, rest) => quote!(format!(#lit, #rest)),
74        }
75    }
76}
77
78/// E.g. `primary, "Expected {} arguments, got {}", diag.expected, diag.got`
79struct DiagnosticAttribute {
80    ident: Ident,
81    message: Option<DiagnosticMessage>,
82}
83
84impl Parse for DiagnosticAttribute {
85    fn parse(input: ParseStream) -> syn::Result<Self> {
86        let ident = input.parse()?;
87        if input.is_empty() {
88            return Ok(DiagnosticAttribute {
89                ident,
90                message: None,
91            });
92        }
93        let _: Token![,] = input.parse()?;
94        let message = Some(input.parse()?);
95        Ok(DiagnosticAttribute { ident, message })
96    }
97}
98
99fn field_attributes(fields: &FieldsNamed) -> Result<Vec<(&Ident, DiagnosticAttribute)>, Error> {
100    fields
101        .named
102        .iter()
103        .filter_map(|field| {
104            field.ident.as_ref().map(|field_ident| {
105                // Zip the attributes together with the field they're on.
106                std::iter::zip(
107                    std::iter::repeat(field_ident),
108                    // Only the #[diagnostic]-attributes
109                    field.attrs.iter().filter(|attr| {
110                        attr.path()
111                            .get_ident()
112                            .map(|ident| ident == "diagnostic")
113                            .unwrap_or(false)
114                    }),
115                )
116            })
117        })
118        .flatten()
119        .map(|(field, attr)| match attr.parse_args() {
120            Ok(attr) => Ok((field, attr)),
121            Err(_) => {
122                Err(Error::new_spanned(attr, "inner attribute is malformed\nexpected #[diagnostic(<primary/secondary>, <MESSAGE...>)]"))
123            }
124        })
125        .collect()
126}
127
128/// Expected usage:
129///
130/// ```ignore
131/// #[derive(IntoDiagnostic, Clone)]
132/// #[diagnostic(error, "Expected argument list")]
133/// pub(crate) struct ExpectedArgumentList {
134///     #[diagnostic(primary, "Expected argument list for this instantiation")]
135///     pub base_expr: Loc<()>,
136///     pub next_token: Token,
137/// }
138/// ```
139fn actual_derive_diagnostic(input: ItemStruct) -> Result<TokenStream, Error> {
140    let fields = match &input.fields {
141        Fields::Named(fields) => fields,
142        Fields::Unnamed(_) | Fields::Unit => {
143            return Err(Error::new(
144                Span::call_site(),
145                "Can only derive IntoDiagnostic on structs with named fields",
146            ));
147        }
148    };
149    let ident = input.ident;
150
151    // Get the top attribute: `#[diagnostic(error, "...")]`
152    let top_attribute = input
153        .attrs
154        .iter()
155        .find(|attr| {
156            attr.path()
157                .get_ident()
158                .map(|ident| ident == "diagnostic")
159                .unwrap_or(false)
160        })
161        .ok_or_else(|| Error::new(Span::call_site(), "missing outer #[diagnostic] attribute"))?;
162    let DiagnosticAttribute {
163        ident: level,
164        message: primary_message,
165    } = top_attribute
166        .parse_args()
167        .map_err(|_| Error::new_spanned(top_attribute, "top attribute is malformed\nexpected something like `#[diagnostic(error, \"uh oh, stinky\")]`"))?;
168    let primary_message = primary_message.map(|msg| msg.quote());
169
170    // Look for field attributes (`#[diagnostic(primary, "...")]`), as a vec of
171    // idents and their diagnostic attribute. Only diagnostic attributes are
172    // handled.
173    let attrs = field_attributes(fields)?;
174    let primary = attrs
175        .iter()
176        .find(|(_, attr)| attr.ident == "primary")
177        .ok_or_else(|| Error::new(Span::call_site(), "primary span is required"))?;
178    let primary_span = primary.0;
179    let primary_label = primary
180        .1
181        .message
182        .as_ref()
183        .map(DiagnosticMessage::quote)
184        .map(|msg| quote!( .primary_label(#msg) ))
185        .unwrap_or_default();
186    let secondary_labels = attrs
187        .iter()
188        .filter(|(_, attr)| attr.ident == "secondary")
189        .map(|(field, attr)| -> Result<_, Error> {
190            let message = attr
191                .message
192                .as_ref()
193                .map(DiagnosticMessage::quote)
194                .ok_or_else(|| {
195                    Error::new(Span::call_site(), "secondary spans require a message")
196                })?;
197            Ok(quote!( .secondary_label(diag.#field, #message) ))
198        })
199        .collect::<Result<Vec<_>, _>>()?;
200
201    // Generate the code, with safe paths.
202    Ok(quote! {
203        impl std::convert::From<#ident> for ::spade_diagnostics::Diagnostic {
204            fn from(diag: #ident) -> Self {
205                ::spade_diagnostics::Diagnostic::#level(
206                    diag.#primary_span,
207                    #primary_message
208                )
209                #primary_label
210                #(#secondary_labels)*
211            }
212        }
213    }
214    .into())
215}
216
217#[proc_macro_derive(IntoDiagnostic, attributes(diagnostic))]
218pub fn derive_diagnostic(input: TokenStream) -> TokenStream {
219    match syn::parse2(input.into())
220        .and_then(actual_derive_diagnostic)
221        .map_err(|e| e.into_compile_error().into())
222    {
223        Ok(ts) | Err(ts) => ts,
224    }
225}
226
227/// Expected usage:
228///
229/// ```ignore
230/// #[derive(IntoSubdiagnostic)]
231/// #[diagnostic(suggestion, "Use `{` if you want to add items to this enum variant")]
232/// pub(crate) struct SuggestBraceEnumVariant {
233///     #[diagnostic(replace, "{")]
234///     pub open_paren: Loc<()>,
235///     #[diagnostic(replace, "}")]
236///     pub close_paren: Loc<()>,
237/// }
238/// ```
239fn actual_derive_subdiagnostic(input: ItemStruct) -> Result<TokenStream, Error> {
240    let fields = match &input.fields {
241        Fields::Named(fields) => fields,
242        Fields::Unnamed(_) | Fields::Unit => {
243            return Err(Error::new(
244                Span::call_site(),
245                "Can only derive IntoSubdiagnostic on structs with named fields",
246            ));
247        }
248    };
249    let ident = input.ident;
250
251    // Get the top attribute: `#[diagnostic(suggestion, "...")]`
252    let top_attribute = input
253        .attrs
254        .iter()
255        .find(|attr| {
256            attr.path()
257                .get_ident()
258                .map(|ident| ident == "diagnostic")
259                .unwrap_or(false)
260        })
261        .ok_or_else(|| Error::new(Span::call_site(), "missing outer #[diagnostic] attribute"))?;
262    let DiagnosticAttribute {
263        ident: subdiag_kind,
264        message,
265    } = top_attribute
266        .parse_args()
267        .map_err(|_| Error::new_spanned(top_attribute, "top attribute is malformed\nexpected something like `#[diagnostic(suggestion, \"uh oh, stinky\")]`"))?;
268    let message = message
269        .as_ref()
270        .map(DiagnosticMessage::quote)
271        .unwrap_or(quote!(""));
272
273    // Look for field attributes (`#[diagnostic(replace, "...")]`), as a vec of
274    // idents and their diagnostic attribute. Only diagnostic attributes are
275    // handled.
276    let attrs = field_attributes(fields)?;
277    let tokens = match subdiag_kind.to_string().as_str() {
278        "suggestion" => {
279            let parts = attrs
280                .iter()
281                .map(|(field, attr)| {
282                    let replacement = attr
283                        .message
284                        .as_ref()
285                        .map(DiagnosticMessage::quote)
286                        .unwrap_or(quote!(""));
287                    match attr.ident.to_string().as_str() {
288                        "replace" => Ok(quote!((diag.#field.into(), #replacement.to_string()))),
289                        "insert_before" | "insert_after" | "remove" => todo!(),
290                        _ => Err(Error::new_spanned(
291                            &attr.ident,
292                            "unknown suggestion part kind",
293                        )),
294                    }
295                })
296                .collect::<Result<Vec<_>, _>>()?;
297            quote! {
298                impl std::convert::From<#ident> for ::spade_diagnostics::diagnostic::Subdiagnostic {
299                    fn from(diag: #ident) -> Self {
300                        ::spade_diagnostics::diagnostic::Subdiagnostic::Suggestion {
301                            parts: vec![#(#parts),*],
302                            message: #message.into(),
303                        }
304                    }
305                }
306            }
307        }
308        _ => {
309            return Err(Error::new_spanned(
310                subdiag_kind,
311                "unknown subdiagnostic kind",
312            ))
313        }
314    };
315    Ok(tokens.into())
316}
317
318#[proc_macro_derive(IntoSubdiagnostic, attributes(diagnostic))]
319pub fn derive_subdiagnostic(input: TokenStream) -> TokenStream {
320    match syn::parse2(input.into())
321        .and_then(actual_derive_subdiagnostic)
322        .map_err(|e| e.into_compile_error().into())
323    {
324        Ok(ts) | Err(ts) => ts,
325    }
326}
327
#[cfg(test)]
mod test {
    /// Registers every file matching `pattern` as a trybuild case that must fail to
    /// compile. trybuild runs the registered cases when the `TestCases` value is
    /// dropped, which happens at the end of this function.
    fn expect_compile_failures(pattern: &str) {
        let cases = trybuild::TestCases::new();
        cases.compile_fail(pattern);
    }

    mod into_diagnostic {
        /// UI tests for `#[derive(IntoDiagnostic)]` error reporting.
        #[test]
        fn ui() {
            super::expect_compile_failures("tests/into_diagnostic/ui/*.rs");
        }
    }

    mod into_subdiagnostic {
        /// UI tests for `#[derive(IntoSubdiagnostic)]` error reporting.
        #[test]
        fn ui() {
            super::expect_compile_failures("tests/into_subdiagnostic/ui/*.rs");
        }
    }
}