spade_mir/passes/
split_compound_regs.rs

1use num::ToPrimitive;
2
3use spade_common::id_tracker::ExprIdTracker;
4
5use crate::{types::Type, Binding, Operator, Register, Statement, ValueName};
6
7use super::MirPass;
8
9pub struct SplitCompoundRegs {}
10
11impl MirPass for SplitCompoundRegs {
12    fn name(&self) -> &'static str {
13        "split_compound_regs"
14    }
15
16    fn transform_statements(
17        &self,
18        stmts: &[Statement],
19        expr_idtracker: &mut ExprIdTracker,
20    ) -> Vec<Statement> {
21        stmts
22            .iter()
23            .flat_map(|stmt| match stmt {
24                Statement::Register(reg) => split_compound_reg(reg, expr_idtracker),
25                other => vec![other.clone()],
26            })
27            .collect()
28    }
29}
30
31fn generate_split_code(
32    reg: &Register,
33    members: &Vec<Type>,
34    expr_idtracker: &mut ExprIdTracker,
35) -> Vec<Statement> {
36    let (reg_names, split_stmts): (Vec<_>, Vec<_>) = members
37        .iter()
38        .enumerate()
39        .map(|(i, member)| {
40            let split_name = ValueName::Expr(expr_idtracker.next());
41            let reg_name = ValueName::Expr(expr_idtracker.next());
42            let split_stmt = Statement::Binding(Binding {
43                name: split_name.clone(),
44                operator: Operator::IndexTuple(i as u64, members.clone()),
45                operands: vec![reg.value.clone()],
46                ty: member.clone(),
47                loc: None,
48            });
49
50            let reg_stmts = split_compound_reg(
51                &Register {
52                    name: reg_name.clone(),
53                    ty: member.clone(),
54                    clock: reg.clock.clone(),
55                    reset: reg.reset.clone(),
56                    initial: None,
57                    value: split_name.clone(),
58                    loc: None,
59                    traced: None,
60                },
61                expr_idtracker,
62            );
63
64            let split_stmts = vec![split_stmt]
65                .into_iter()
66                .chain(reg_stmts)
67                .collect::<Vec<_>>();
68
69            (reg_name, split_stmts)
70        })
71        .unzip();
72
73    let new_compound = Statement::Binding(Binding {
74        name: reg.name.clone(),
75        operator: Operator::ConstructTuple,
76        operands: reg_names,
77        ty: reg.ty.clone(),
78        loc: None,
79    });
80
81    split_stmts
82        .into_iter()
83        .flatten()
84        .chain(vec![new_compound])
85        .collect()
86}
87
88fn split_compound_reg(reg: &Register, expr_idtracker: &mut ExprIdTracker) -> Vec<Statement> {
89    if reg.initial.is_some() {
90        return vec![Statement::Register(reg.clone())];
91    }
92
93    match &reg.ty {
94        Type::Int(_)
95        | Type::UInt(_)
96        | Type::Bool
97        | Type::InOut(_)
98        | Type::Enum(_)
99        | Type::Backward(_)
100        | Type::Memory { .. } => vec![Statement::Register(reg.clone())],
101        Type::Tuple(members) => generate_split_code(reg, members, expr_idtracker),
102        Type::Struct(members) => generate_split_code(
103            reg,
104            &members.iter().map(|(_, ty)| ty.clone()).collect(),
105            expr_idtracker,
106        ),
107        // NOTE: Arrays are currently split as if they were tuples. This means that
108        // things will be a bit weird in the MIR, but it does make codegen for all this
109        // much easier as it doesn't require generating array indices as runtime constants.
110        Type::Array { inner, length } => {
111            if let Some(length) = length.to_usize() {
112                generate_split_code(
113                    reg,
114                    &(0..(length)).map(|_| *inner.clone()).collect::<_>(),
115                    expr_idtracker,
116                )
117            } else {
118                vec![Statement::Register(reg.clone())]
119            }
120        }
121    }
122}
123
#[cfg(test)]
mod test {
    //! Tests for `SplitCompoundRegs`.
    //!
    //! The pass allocates expression ids starting at 100 while the expected MIR uses small
    //! ids (`e(10)`, ...). Presumably `assert_same_mir!` compares modulo expression-id
    //! renaming — TODO confirm.

    use colored::Colorize;

    use spade_common::id_tracker::ExprIdTracker;

    use super::SplitCompoundRegs;
    use crate::passes::MirPass;
    use crate::Statement;
    use crate::{self as spade_mir, assert_same_mir};
    use crate::{entity, types::Type};

    /// Runs the pass over `statements` with a fresh id tracker starting at 100.
    fn run_split(statements: &[Statement]) -> Vec<Statement> {
        let pass = SplitCompoundRegs {};
        pass.transform_statements(statements, &mut ExprIdTracker::new_at(100))
    }

    /// A flat tuple register becomes one register per member plus a `ConstructTuple`.
    #[test]
    fn splitting_tuple_works() {
        let members = vec![Type::int(4), Type::int(8)];
        let reg_ty = Type::Tuple(members.clone());

        let before = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), reg_ty.clone()) -> Type::int(6); {
            (reg n(1, "value"); reg_ty.clone(); clock (n(0, "clk")); n(2, "val"));
        } => n(1, "value"));

        let mut after = before.clone();
        after.statements = run_split(&before.statements);

        let expected = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), reg_ty.clone()) -> Type::int(6); {
            (e(10); Type::int(4); IndexTuple((0, members.clone())); n(2, "val"));
            (reg e(11); Type::int(4); clock (n(0, "clk")); e(10));
            (e(20); Type::int(8); IndexTuple((1, members.clone())); n(2, "val"));
            (reg e(21); Type::int(8); clock (n(0, "clk")); e(20));

            (n(1, "value"); reg_ty; ConstructTuple; e(11), e(21));
        } => n(1, "value"));

        assert_same_mir!(&after, &expected);
    }

    /// Struct registers are split positionally, exactly like the equivalent tuple.
    #[test]
    fn splitting_struct_works() {
        let members = vec![Type::int(4), Type::int(8)];
        let reg_ty = Type::Struct(vec![
            ("a".to_string(), Type::int(4)),
            ("b".to_string(), Type::int(8)),
        ]);

        let before = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), reg_ty.clone()) -> Type::int(6); {
            (reg n(1, "value"); reg_ty.clone(); clock (n(0, "clk")); n(2, "val"));
        } => n(1, "value"));

        let mut after = before.clone();
        after.statements = run_split(&before.statements);

        let expected = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), reg_ty.clone()) -> Type::int(6); {
            (e(10); Type::int(4); IndexTuple((0, members.clone())); n(2, "val"));
            (reg e(11); Type::int(4); clock (n(0, "clk")); e(10));
            (e(20); Type::int(8); IndexTuple((1, members.clone())); n(2, "val"));
            (reg e(21); Type::int(8); clock (n(0, "clk")); e(20));

            (n(1, "value"); reg_ty; ConstructTuple; e(11), e(21));
        } => n(1, "value"));

        assert_same_mir!(&after, &expected);
    }

    /// Nested compound members are split recursively, and each level is rebuilt with its own
    /// `ConstructTuple`.
    #[test]
    fn splitting_compound_compounds_works() {
        let inner_members = vec![Type::int(4), Type::int(8)];
        let inner_ty = Type::Tuple(inner_members.clone());
        let members = vec![Type::int(4), inner_ty.clone()];
        let reg_ty = Type::Tuple(members.clone());

        let before = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), reg_ty.clone()) -> Type::int(6); {
            (reg n(1, "value"); reg_ty.clone(); clock (n(0, "clk")); n(2, "val"));
        } => n(1, "value"));

        let mut after = before.clone();
        after.statements = run_split(&before.statements);

        let expected = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), reg_ty.clone()) -> Type::int(6); {
            (e(10); Type::int(4); IndexTuple((0, members.clone())); n(2, "val"));
            (reg e(11); Type::int(4); clock (n(0, "clk")); e(10));
            (e(20); inner_ty.clone(); IndexTuple((1, members.clone())); n(2, "val"));

            (e(30); Type::int(4); IndexTuple((0, inner_members.clone())); e(20));
            (reg e(31); Type::int(4); clock (n(0, "clk")); e(30));
            (e(40); Type::int(8); IndexTuple((1, inner_members.clone())); e(20));
            (reg e(41); Type::int(8); clock (n(0, "clk")); e(40));

            (e(21); inner_ty; ConstructTuple; e(31), e(41));

            (n(1, "value"); reg_ty; ConstructTuple; e(11), e(21));
        } => n(1, "value"));

        assert_same_mir!(&after, &expected);
    }
}
223}