1use num::ToPrimitive;
2
3use spade_common::id_tracker::ExprIdTracker;
4
5use crate::{types::Type, Binding, Operator, Register, Statement, ValueName};
6
7use super::MirPass;
8
9pub struct SplitCompoundRegs {}
10
11impl MirPass for SplitCompoundRegs {
12 fn name(&self) -> &'static str {
13 "split_compound_regs"
14 }
15
16 fn transform_statements(
17 &self,
18 stmts: &[Statement],
19 expr_idtracker: &mut ExprIdTracker,
20 ) -> Vec<Statement> {
21 stmts
22 .iter()
23 .flat_map(|stmt| match stmt {
24 Statement::Register(reg) => split_compound_reg(reg, expr_idtracker),
25 other => vec![other.clone()],
26 })
27 .collect()
28 }
29}
30
31fn generate_split_code(
32 reg: &Register,
33 members: &Vec<Type>,
34 expr_idtracker: &mut ExprIdTracker,
35) -> Vec<Statement> {
36 let (reg_names, split_stmts): (Vec<_>, Vec<_>) = members
37 .iter()
38 .enumerate()
39 .map(|(i, member)| {
40 let split_name = ValueName::Expr(expr_idtracker.next());
41 let reg_name = ValueName::Expr(expr_idtracker.next());
42 let split_stmt = Statement::Binding(Binding {
43 name: split_name.clone(),
44 operator: Operator::IndexTuple(i as u64, members.clone()),
45 operands: vec![reg.value.clone()],
46 ty: member.clone(),
47 loc: None,
48 });
49
50 let reg_stmts = split_compound_reg(
51 &Register {
52 name: reg_name.clone(),
53 ty: member.clone(),
54 clock: reg.clock.clone(),
55 reset: reg.reset.clone(),
56 initial: None,
57 value: split_name.clone(),
58 loc: None,
59 traced: None,
60 },
61 expr_idtracker,
62 );
63
64 let split_stmts = vec![split_stmt]
65 .into_iter()
66 .chain(reg_stmts)
67 .collect::<Vec<_>>();
68
69 (reg_name, split_stmts)
70 })
71 .unzip();
72
73 let new_compound = Statement::Binding(Binding {
74 name: reg.name.clone(),
75 operator: Operator::ConstructTuple,
76 operands: reg_names,
77 ty: reg.ty.clone(),
78 loc: None,
79 });
80
81 split_stmts
82 .into_iter()
83 .flatten()
84 .chain(vec![new_compound])
85 .collect()
86}
87
88fn split_compound_reg(reg: &Register, expr_idtracker: &mut ExprIdTracker) -> Vec<Statement> {
89 if reg.initial.is_some() {
90 return vec![Statement::Register(reg.clone())];
91 }
92
93 match ®.ty {
94 Type::Int(_)
95 | Type::UInt(_)
96 | Type::Bool
97 | Type::InOut(_)
98 | Type::Enum(_)
99 | Type::Backward(_)
100 | Type::Memory { .. } => vec![Statement::Register(reg.clone())],
101 Type::Tuple(members) => generate_split_code(reg, members, expr_idtracker),
102 Type::Struct(members) => generate_split_code(
103 reg,
104 &members.iter().map(|(_, ty)| ty.clone()).collect(),
105 expr_idtracker,
106 ),
107 Type::Array { inner, length } => {
111 if let Some(length) = length.to_usize() {
112 generate_split_code(
113 reg,
114 &(0..(length)).map(|_| *inner.clone()).collect::<_>(),
115 expr_idtracker,
116 )
117 } else {
118 vec![Statement::Register(reg.clone())]
119 }
120 }
121 }
122}
123
#[cfg(test)]
mod test {
    use colored::Colorize;

    use spade_common::id_tracker::ExprIdTracker;

    use super::SplitCompoundRegs;
    use crate::passes::MirPass;
    use crate::{self as spade_mir, assert_same_mir};
    use crate::{entity, types::Type};

    // NOTE(review): the pass allocates expression ids starting at 100, while the
    // expected MIR is written with e(10), e(11), ...; presumably
    // `assert_same_mir!` compares expression ids up to a consistent renaming —
    // TODO confirm.

    /// A `(int<4>, int<8>)` register becomes two member registers, and the
    /// original name is rebound with `ConstructTuple`.
    #[test]
    fn splitting_tuple_works() {
        let members = vec![Type::int(4), Type::int(8)];
        let ty = Type::Tuple(members.clone());

        let before = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), ty.clone()) -> Type::int(6); {
            (reg n(1, "value"); ty.clone(); clock (n(0, "clk")); n(2, "val"));
        } => n(1, "value"));

        let pass = SplitCompoundRegs {};
        let mut idtracker = ExprIdTracker::new_at(100);
        let after = {
            let mut transformed = before.clone();
            transformed.statements =
                pass.transform_statements(&before.statements, &mut idtracker);
            transformed
        };

        let expected = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), ty.clone()) -> Type::int(6); {
            (e(10); Type::int(4); IndexTuple((0, members.clone())); n(2, "val"));
            (reg e(11); Type::int(4); clock (n(0, "clk")); e(10));
            (e(20); Type::int(8); IndexTuple((1, members.clone())); n(2, "val"));
            (reg e(21); Type::int(8); clock (n(0, "clk")); e(20));

            (n(1, "value"); ty; ConstructTuple; e(11), e(21));
        } => n(1, "value"));

        assert_same_mir!(&after, &expected);
    }

    /// Struct registers are split by field type, using tuple indexing.
    #[test]
    fn splitting_struct_works() {
        let members = vec![Type::int(4), Type::int(8)];
        let ty = Type::Struct(vec![
            ("a".to_string(), Type::int(4)),
            ("b".to_string(), Type::int(8)),
        ]);

        let before = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), ty.clone()) -> Type::int(6); {
            (reg n(1, "value"); ty.clone(); clock (n(0, "clk")); n(2, "val"));
        } => n(1, "value"));

        let pass = SplitCompoundRegs {};
        let mut idtracker = ExprIdTracker::new_at(100);
        let after = {
            let mut transformed = before.clone();
            transformed.statements =
                pass.transform_statements(&before.statements, &mut idtracker);
            transformed
        };

        let expected = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), ty.clone()) -> Type::int(6); {
            (e(10); Type::int(4); IndexTuple((0, members.clone())); n(2, "val"));
            (reg e(11); Type::int(4); clock (n(0, "clk")); e(10));
            (e(20); Type::int(8); IndexTuple((1, members.clone())); n(2, "val"));
            (reg e(21); Type::int(8); clock (n(0, "clk")); e(20));

            (n(1, "value"); ty; ConstructTuple; e(11), e(21));
        } => n(1, "value"));

        assert_same_mir!(&after, &expected);
    }

    /// Nested aggregates are split recursively: the inner tuple member is itself
    /// split before being reassembled.
    #[test]
    fn splitting_compound_compounds_works() {
        let inner_members = vec![Type::int(4), Type::int(8)];
        let inner_ty = Type::Tuple(inner_members.clone());
        let members = vec![Type::int(4), inner_ty.clone()];
        let ty = Type::Tuple(members.clone());

        let before = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), ty.clone()) -> Type::int(6); {
            (reg n(1, "value"); ty.clone(); clock (n(0, "clk")); n(2, "val"));
        } => n(1, "value"));

        let pass = SplitCompoundRegs {};
        let mut idtracker = ExprIdTracker::new_at(100);
        let after = {
            let mut transformed = before.clone();
            transformed.statements =
                pass.transform_statements(&before.statements, &mut idtracker);
            transformed
        };

        let expected = entity!("pong"; ("_i_clk", n(0, "clk"), Type::Bool, "val", n(2, "val"), ty.clone()) -> Type::int(6); {
            (e(10); Type::int(4); IndexTuple((0, members.clone())); n(2, "val"));
            (reg e(11); Type::int(4); clock (n(0, "clk")); e(10));
            (e(20); inner_ty.clone(); IndexTuple((1, members.clone())); n(2, "val"));

            (e(30); Type::int(4); IndexTuple((0, inner_members.clone())); e(20));
            (reg e(31); Type::int(4); clock (n(0, "clk")); e(30));
            (e(40); Type::int(8); IndexTuple((1, inner_members.clone())); e(20));
            (reg e(41); Type::int(8); clock (n(0, "clk")); e(40));

            (e(21); inner_ty; ConstructTuple; e(31), e(41));

            (n(1, "value"); ty; ConstructTuple; e(11), e(21));
        } => n(1, "value"));

        assert_same_mir!(&after, &expected);
    }
}