spade_mir/passes/
auto_clock_gating.rs

1use num::BigUint;
2use spade_common::{id_tracker::ExprIdTracker, location_info::Loc, num_ext::InfallibleToBigUint};
3
4use crate::{types::Type, Binding, Operator, Register, Statement, ValueName};
5
6use super::MirPass;
7
8/// Splits an 2 variant enum with variant 0 having payload of size 0 and variant 1
9/// having another size into a tag and payload
10fn split_trivial_tag_value(
11    value: &ValueName,
12    variants: &Vec<Vec<Type>>,
13    statements: &mut Vec<Statement>,
14    expr_idtracker: &mut ExprIdTracker,
15    loc: &Option<Loc<()>>,
16) -> (ValueName, ValueName) {
17    let tag_name = ValueName::Expr(expr_idtracker.next());
18    let payload_name = ValueName::Expr(expr_idtracker.next());
19
20    let payload_size = variants[1].iter().map(|v| v.size()).sum::<BigUint>();
21    statements.push(Statement::Binding(Binding {
22        name: tag_name.clone(),
23        operator: Operator::RangeIndexBits {
24            start: payload_size.clone(),
25            end_exclusive: &payload_size + 1u32.to_biguint(),
26        },
27        operands: vec![value.clone()],
28        ty: Type::Bool,
29        loc: *loc,
30    }));
31    statements.push(Statement::Binding(Binding {
32        name: payload_name.clone(),
33        operator: Operator::RangeIndexBits {
34            start: 0u32.to_biguint(),
35            end_exclusive: payload_size,
36        },
37        operands: vec![value.clone()],
38        ty: Type::Tuple(variants[1].clone()),
39        loc: *loc,
40    }));
41
42    (tag_name, payload_name)
43}
44
45impl Register {
46    fn perform_trivial_gating(&self, expr_idtracker: &mut ExprIdTracker) -> Option<Vec<Statement>> {
47        // FIXME: For now, we'll not split registers initial values because those would need
48        // special treatment since their values are comptime-evaluated
49        if self.initial.is_some() {
50            return None;
51        }
52        match &self.ty {
53            crate::types::Type::Enum(variants) => {
54                if variants.len() == 2 && variants[0].len() == 0 {
55                    let mut new_statements = vec![];
56
57                    let tag_reg_name = ValueName::Expr(expr_idtracker.next());
58                    let payload_reg_name = ValueName::Expr(expr_idtracker.next());
59                    let payload_reg_value_name = ValueName::Expr(expr_idtracker.next());
60
61                    let payload_type = Type::Tuple(variants[1].clone());
62
63                    let (value_tag, value_payload) = split_trivial_tag_value(
64                        &self.value,
65                        variants,
66                        &mut new_statements,
67                        expr_idtracker,
68                        &self.loc,
69                    );
70                    let (reset_tag, reset_payload) =
71                        if let Some((reset_trig, reset_val)) = &self.reset {
72                            let (tag, payload) = split_trivial_tag_value(
73                                reset_val,
74                                variants,
75                                &mut new_statements,
76                                expr_idtracker,
77                                &self.loc,
78                            );
79                            (
80                                Some((reset_trig.clone(), tag)),
81                                Some((reset_trig.clone(), payload)),
82                            )
83                        } else {
84                            (None, None)
85                        };
86
87                    new_statements.push(Statement::Register(Register {
88                        name: tag_reg_name.clone(),
89                        ty: Type::Bool,
90                        clock: self.clock.clone(),
91                        reset: reset_tag,
92                        initial: self.initial.as_ref().map(|_| panic!("Had initial")),
93                        value: value_tag.clone(),
94                        loc: self.loc,
95                        // FIXME: wal-tracing breaks with this change
96                        traced: None,
97                    }));
98                    new_statements.push(Statement::Binding(Binding {
99                        name: payload_reg_value_name.clone(),
100                        operator: Operator::Select,
101                        operands: vec![
102                            value_tag.clone(),
103                            value_payload.clone(),
104                            payload_reg_name.clone(),
105                        ],
106                        ty: payload_type.clone(),
107                        loc: self.loc,
108                    }));
109                    new_statements.push(Statement::Register(Register {
110                        name: payload_reg_name.clone(),
111                        ty: payload_type,
112                        clock: self.clock.clone(),
113                        reset: reset_payload,
114                        initial: self.initial.as_ref().map(|_| panic!("Had initial")),
115                        value: payload_reg_value_name,
116                        loc: self.loc,
117                        // FIXME: wal-tracing breaks with this change
118                        traced: None,
119                    }));
120                    new_statements.push(Statement::Binding(Binding {
121                        name: self.name.clone(),
122                        operator: Operator::Concat,
123                        operands: vec![tag_reg_name.clone(), payload_reg_name.clone()],
124                        ty: self.ty.clone(),
125                        loc: self.loc,
126                    }));
127
128                    Some(new_statements)
129                } else {
130                    None
131                }
132            }
133            _ => None,
134        }
135    }
136}
137
138pub struct AutoGating {}
139
140impl MirPass for AutoGating {
141    fn transform_statements(
142        &self,
143        stmts: &[Statement],
144        expr_idtracker: &mut ExprIdTracker,
145    ) -> Vec<Statement> {
146        stmts
147            .iter()
148            .flat_map(|stmt| match stmt {
149                Statement::Register(reg) => reg
150                    .perform_trivial_gating(expr_idtracker)
151                    .unwrap_or_else(|| vec![stmt.clone()]),
152                other => vec![other.clone()],
153            })
154            .collect()
155    }
156
157    fn name(&self) -> &'static str {
158        "enum_clock_gating"
159    }
160}