spade_mir/
verilator_wrapper.rs

1//! This module generates wrappers around verilator modules which are used
2//! with spade-cxx to feed values to and from verilated units
3
4use itertools::Itertools;
5use nesty::{code, Code};
6use num::ToPrimitive;
7use spade_common::num_ext::InfallibleToBigUint;
8
9use crate::{codegen::mangle_input, types::Type, unit_name::UnitNameKind, Entity};
10
11impl Type {
12    fn output_wrappers(
13        &self,
14        root_class: &str,
15        path: Vec<&str>,
16        class_name: &str,
17    ) -> (String, String) {
18        if self.size() == 0u32.to_biguint() {
19            return (String::new(), String::new());
20        }
21        let (field_decls, field_impls, field_members, field_constructor_calls) = match self {
22            Type::Struct(fields) => fields
23                .iter()
24                .map(|(name, ty)| {
25                    let mut path = path.clone();
26                    path.push(name);
27                    let field_class_name = format!("{class_name}_{name}");
28                    let (decls, impls) = ty.output_wrappers(root_class, path, &field_class_name);
29
30                    let member = format!("{field_class_name}* {name};");
31
32                    let constructor = format!(", {name}(init_{field_class_name}(root))");
33
34                    (decls, impls, member, constructor)
35                })
36                .multiunzip(),
37            _ => (
38                "".to_string(),
39                "".to_string(),
40                "".to_string(),
41                "".to_string(),
42            ),
43        };
44
45        let field_as_strings = path.iter().map(|p| format!(r#""{p}""#)).join(", ");
46        let declaration = code! {
47            [0] field_decls;
48            [0] format!("class {class_name};");
49            [0] format!("{class_name}* init_{class_name}({root_class}* root);");
50        }
51        .to_string();
52
53        let implementation = code! {
54            [0] field_impls;
55            [0] format!("class {class_name} {{");
56            [1]     "public:";
57            [2]         format!("{class_name}({root_class}* root)");
58            [4]         format!(": root(root) ");
59            [4]         field_constructor_calls;
60            [2]         "{}";
61            [4]         format!("{root_class}* root;");
62            [2]         "bool operator==(std::string const& other) const {";
63            [3]             format!(r#"auto field = root->s_ext->output_field({{{field_as_strings}}});"#);
64            [3]             format!("auto val = spade::new_bit_string(root->output_string_fn());");
65            [3]             format!(r#"return root"#);
66            [3]             format!(r#"         ->s_ext"#);
67            [3]             format!(r#"         ->compare_field(*field, other, *val)"#);
68            [3]             format!(r#"         ->matches();"#);
69            [2]         "}";
70            [2]         format!("void assert_eq(std::string const& expected, std::string const& source_loc) {{");
71            [3]             format!(r#"auto field = root->s_ext->output_field({{{field_as_strings}}});"#);
72            [3]             format!("auto val = spade::new_bit_string(root->output_string_fn());");
73            [3]             format!(r#"root"#);
74            [3]             format!(r#"    ->s_ext"#);
75            [3]             format!(r#"    ->assert_eq(*field, expected, *val, source_loc);"#);
76            [2]         "}";
77            [2]         format!("std::string spade_repr() {{");
78            [3]             format!(r#"auto field = root->s_ext->output_field({{{field_as_strings}}});"#);
79            [3]             format!("auto val = spade::new_bit_string(root->output_string_fn());");
80            [3]             format!(r#"return std::string(root"#);
81            [3]             format!(r#"    ->s_ext"#);
82            [3]             format!(r#"    ->field_value(*field, *val));"#);
83            [2]         "}";
84            [2]         field_members;
85            [0] "};";
86            [0] format!("{class_name}* init_{class_name}({root_class}* root) {{");
87            [1]     format!("return new {class_name}(root);");
88            [0] "}";
89        }.to_string();
90
91        (declaration, implementation)
92    }
93}
94
95impl Entity {
96    fn input_wrapper(&self, parent_class_name: &str) -> (String, String) {
97        let class_name = format!("{parent_class_name}_i");
98
99        let (constructor_calls, fields_in_parent, field_classes): (Vec<_>, Vec<_>, Vec<_>) = self
100            .inputs
101            .iter()
102            .filter(|f| f.ty.size() != 0u32.to_biguint())
103            .map(|f| {
104                let field_name = &f.name;
105                let field_name_mangled = mangle_input(&f.no_mangle, &f.name);
106                let field_class_name = format!("{class_name}_{field_name}");
107
108                let assignment = if f.ty.size() <= 64u32.to_biguint() {
109                    format!("parent.top->{field_name_mangled} = value->as_u64();")
110                } else {
111                    let size_u64 = f.ty.size().to_u64().expect("Input size does not fit in u64");
112                    code! {
113                        [0] "auto value_split = value->as_u32_chunks();";
114                        [0] (0..(size_u64 / 32))
115                            .map(|i| format!("parent.top->{field_name_mangled}[{i}] = value_split[{i}];"))
116                            .join("\n");
117                        [0] if size_u64 % 32 != 0 {
118                                let idx = size_u64 / 32;
119                                vec![format!("parent.top->{field_name_mangled}[{idx}] = value_split[{idx}];")]
120                            } else {
121                                vec![]
122                            }
123                    }.to_string()
124                };
125
126                let class = code! {
127                    [0] format!("class {field_class_name} {{");
128                    [1]     "public:";
129                    [2]         format!("{field_class_name}({parent_class_name}& parent)");
130                    [3]             ": parent(parent)";
131                    [2]         "{}";
132                    [2]         format!("{field_class_name}& operator=(std::string const& val) {{");
133                    [3]             format!(r#"auto value = parent.s_ext->port_value("{field_name}", val);"#);
134                    [3]             assignment;
135                    [3]             "return *this;";
136                    [2]         "}";
137                    [1]     "private:";
138                    [2]         format!("{parent_class_name}& parent;");
139                    [0] "};"
140                }.to_string();
141
142                let constructor_call = format!(", {field_name}(parent)");
143                let field = format!("{field_class_name} {field_name};");
144
145                (constructor_call, field, class)
146            })
147            .multiunzip();
148
149        let constructor = code! {
150            [0] format!("{class_name}({parent_class_name}& parent)");
151            [1]     ": parent(parent)";
152            [1]     constructor_calls;
153            [0] "{}"
154        };
155
156        let pre_declaration = code! {
157            [0] format!("class {class_name};");
158            // NOTE: Ugly hack to avoid having to generate both a cpp and hpp file while retaining
159            // a loop in the 'dependency graph'. We'll define a free standing function to
160            // initialize the input
161            [0] format!("{class_name}* init_{class_name}({parent_class_name}& t);");
162        }
163        .to_string();
164        let implementation = code! {
165            [0] field_classes;
166            [0] format!("class {class_name} {{");
167            [1]     "public:";
168            [2]         constructor;
169            [2]         fields_in_parent;
170            [1]     "private:";
171            [2]         format!("{parent_class_name}& parent;");
172            [0] "};";
173            [0] format!("{class_name}* init_{class_name}({parent_class_name}& t) {{");
174            [1]     format!("return new {class_name}(t);");
175            [0] "}"
176        }
177        .to_string();
178
179        (pre_declaration, implementation)
180    }
181
182    pub fn verilator_wrapper(&self) -> Option<String> {
183        // Units which are mangled have no stable name in verilator, so we won't generate
184        // them
185        let name = match &self.name.kind {
186            UnitNameKind::Unescaped(name) => name,
187            UnitNameKind::Escaped { .. } => return None,
188        };
189
190        let class_name = format!("{name}_spade_t");
191        let output_class_name = format!("{class_name}_o");
192
193        let has_output = self.output_type.size() != 0u32.to_biguint();
194
195        let constructor = code! {
196            [0] format!("{class_name}(std::string spade_state, std::string spade_top, V{name}* top)");
197            [1]     ": s_ext(spade::setup_spade(spade_top, spade_state))";
198            [1]     ", top(top)";
199            [1]     format!(", i(init_{class_name}_i(*this))");
200            [1]     if has_output {format!(", o(init_{class_name}_o(this))")} else {String::new()};
201            [0]  "{";
202            [0] "}";
203        };
204
205        let (output_declaration, output_impl) =
206            self.output_type
207                .output_wrappers(&class_name, vec![], &output_class_name);
208
209        let size = self.output_type.size();
210        let size_u64 = self
211            .output_type
212            .size()
213            .to_u64()
214            .expect("Output size does not fit in 64 bits");
215        let output_string_generator = if !has_output {
216            code! {}
217        } else if size <= 64u32.to_biguint() {
218            code! {
219                [0] format!("std::bitset<{size}> bits = this->top->output___05F;");
220                [0] "std::stringstream ss;";
221                [0] "ss << bits;";
222                [0] "return ss.str();";
223            }
224        } else {
225            code! {
226                [0] "std::bitset<32> bits;";
227                [0] "std::stringstream ss;";
228                [0] if size_u64 % 32 != 0 {
229                        code!{
230                            [0] format!("std::bitset<{}> bits_;", size_u64 % 32);
231                            [0] format!("bits_ = this->top->output___05F[{}];", size_u64 / 32);
232                            [0] format!("ss << bits_;")
233                        }
234                    }
235                    else {
236                        code!{}
237                    };
238                [0] (0..(size / 32u32).to_u64().unwrap())
239                    .rev()
240                    .map(|i| {
241                        code! {
242                            [0] format!("bits = this->top->output___05F[{i}];");
243                            [0] format!("ss << bits;")
244                        }.to_string()
245                    })
246                    .join("\n");
247                [0] "return ss.str();";
248            }
249        };
250
251        let output_string_fn = code! {
252            [0] "std::string output_string_fn() {";
253            [0]     output_string_generator;
254            [0] "}";
255        };
256
257        let (input_pre, input_impl) = self.input_wrapper(&class_name);
258        let class = code! {
259            [0] "#include <sstream>";
260            [0] "#include <bitset>";
261            [0] format!("#if __has_include(<V{name}.h>)");
262            [0] format!(r#"#include <V{name}.h>"#);
263            [0] format!("class {class_name};");
264            [0] input_pre;
265            [0] output_declaration;
266            [0] format!("class {class_name} {{");
267            [1]     "public:";
268            [2]         constructor;
269            [2]         format!("{class_name}_i* i;");
270            [2]         if has_output {
271                            format!("{class_name}_o* o;")
272                        } else {
273                            String::new()
274                        };
275            [2]         "rust::Box<spade::SimulationExt> s_ext;";
276            [2]         format!("V{name}* top;");
277            [2]         output_string_fn;
278            [0] "};";
279            [0] input_impl;
280            [0] output_impl;
281            [0] "#endif";
282        };
283
284        Some(class.to_string())
285    }
286}
287
288pub fn verilator_wrappers(entities: &[&Entity]) -> String {
289    let inner = entities
290        .iter()
291        .filter_map(|e| Entity::verilator_wrapper(e))
292        .collect::<Vec<_>>();
293
294    code! {
295        [0] "#pragma once";
296        [0] "#include <string>";
297        [0] inner
298    }
299    .to_string()
300}