spade_mir/passes/
auto_clock_gating.rs
1use num::BigUint;
2use spade_common::{id_tracker::ExprIdTracker, location_info::Loc, num_ext::InfallibleToBigUint};
3
4use crate::{types::Type, Binding, Operator, Register, Statement, ValueName};
5
6use super::MirPass;
7
8fn split_trivial_tag_value(
11 value: &ValueName,
12 variants: &Vec<Vec<Type>>,
13 statements: &mut Vec<Statement>,
14 expr_idtracker: &mut ExprIdTracker,
15 loc: &Option<Loc<()>>,
16) -> (ValueName, ValueName) {
17 let tag_name = ValueName::Expr(expr_idtracker.next());
18 let payload_name = ValueName::Expr(expr_idtracker.next());
19
20 let payload_size = variants[1].iter().map(|v| v.size()).sum::<BigUint>();
21 statements.push(Statement::Binding(Binding {
22 name: tag_name.clone(),
23 operator: Operator::RangeIndexBits {
24 start: payload_size.clone(),
25 end_exclusive: &payload_size + 1u32.to_biguint(),
26 },
27 operands: vec![value.clone()],
28 ty: Type::Bool,
29 loc: *loc,
30 }));
31 statements.push(Statement::Binding(Binding {
32 name: payload_name.clone(),
33 operator: Operator::RangeIndexBits {
34 start: 0u32.to_biguint(),
35 end_exclusive: payload_size,
36 },
37 operands: vec![value.clone()],
38 ty: Type::Tuple(variants[1].clone()),
39 loc: *loc,
40 }));
41
42 (tag_name, payload_name)
43}
44
45impl Register {
46 fn perform_trivial_gating(&self, expr_idtracker: &mut ExprIdTracker) -> Option<Vec<Statement>> {
47 if self.initial.is_some() {
50 return None;
51 }
52 match &self.ty {
53 crate::types::Type::Enum(variants) => {
54 if variants.len() == 2 && variants[0].len() == 0 {
55 let mut new_statements = vec![];
56
57 let tag_reg_name = ValueName::Expr(expr_idtracker.next());
58 let payload_reg_name = ValueName::Expr(expr_idtracker.next());
59 let payload_reg_value_name = ValueName::Expr(expr_idtracker.next());
60
61 let payload_type = Type::Tuple(variants[1].clone());
62
63 let (value_tag, value_payload) = split_trivial_tag_value(
64 &self.value,
65 variants,
66 &mut new_statements,
67 expr_idtracker,
68 &self.loc,
69 );
70 let (reset_tag, reset_payload) =
71 if let Some((reset_trig, reset_val)) = &self.reset {
72 let (tag, payload) = split_trivial_tag_value(
73 reset_val,
74 variants,
75 &mut new_statements,
76 expr_idtracker,
77 &self.loc,
78 );
79 (
80 Some((reset_trig.clone(), tag)),
81 Some((reset_trig.clone(), payload)),
82 )
83 } else {
84 (None, None)
85 };
86
87 new_statements.push(Statement::Register(Register {
88 name: tag_reg_name.clone(),
89 ty: Type::Bool,
90 clock: self.clock.clone(),
91 reset: reset_tag,
92 initial: self.initial.as_ref().map(|_| panic!("Had initial")),
93 value: value_tag.clone(),
94 loc: self.loc,
95 traced: None,
97 }));
98 new_statements.push(Statement::Binding(Binding {
99 name: payload_reg_value_name.clone(),
100 operator: Operator::Select,
101 operands: vec![
102 value_tag.clone(),
103 value_payload.clone(),
104 payload_reg_name.clone(),
105 ],
106 ty: payload_type.clone(),
107 loc: self.loc,
108 }));
109 new_statements.push(Statement::Register(Register {
110 name: payload_reg_name.clone(),
111 ty: payload_type,
112 clock: self.clock.clone(),
113 reset: reset_payload,
114 initial: self.initial.as_ref().map(|_| panic!("Had initial")),
115 value: payload_reg_value_name,
116 loc: self.loc,
117 traced: None,
119 }));
120 new_statements.push(Statement::Binding(Binding {
121 name: self.name.clone(),
122 operator: Operator::Concat,
123 operands: vec![tag_reg_name.clone(), payload_reg_name.clone()],
124 ty: self.ty.clone(),
125 loc: self.loc,
126 }));
127
128 Some(new_statements)
129 } else {
130 None
131 }
132 }
133 _ => None,
134 }
135 }
136}
137
/// Stateless MIR pass that hands every register statement to
/// `Register::perform_trivial_gating`.
pub struct AutoGating;
139
140impl MirPass for AutoGating {
141 fn transform_statements(
142 &self,
143 stmts: &[Statement],
144 expr_idtracker: &mut ExprIdTracker,
145 ) -> Vec<Statement> {
146 stmts
147 .iter()
148 .flat_map(|stmt| match stmt {
149 Statement::Register(reg) => reg
150 .perform_trivial_gating(expr_idtracker)
151 .unwrap_or_else(|| vec![stmt.clone()]),
152 other => vec![other.clone()],
153 })
154 .collect()
155 }
156
157 fn name(&self) -> &'static str {
158 "enum_clock_gating"
159 }
160}