spade_mir/
types.rs

1use num::{BigUint, Zero};
2use spade_common::num_ext::InfallibleToBigUint;
3
4#[derive(Clone, PartialEq, Eq, Hash, Debug)]
5pub enum Type {
6    Int(BigUint),
7    UInt(BigUint),
8    Bool,
9    Tuple(Vec<Type>),
10    Struct(Vec<(String, Type)>),
11    Array {
12        inner: Box<Type>,
13        length: BigUint,
14    },
15    Memory {
16        inner: Box<Type>,
17        length: BigUint,
18    },
19    Enum(Vec<Vec<Type>>),
20    /// A type in which values flow the opposite way compared to normal types. When a type
21    /// containing a Backward<T> is returned, the module 'returning' it has an additional *input*
22    /// for the wire, and if it takes an input with, n additional *output* port is created.
23    Backward(Box<Type>),
24    InOut(Box<Type>),
25}
26
27impl Type {
28    pub fn int(val: u32) -> Self {
29        Self::Int(val.to_biguint())
30    }
31    pub fn uint(val: u32) -> Self {
32        Self::UInt(val.to_biguint())
33    }
34    pub fn backward(inner: Type) -> Self {
35        Self::Backward(Box::new(inner))
36    }
37    pub fn unit() -> Self {
38        Self::Tuple(Vec::new())
39    }
40
41    pub fn size(&self) -> BigUint {
42        match self {
43            Type::Int(len) => len.clone(),
44            Type::UInt(len) => len.clone(),
45            Type::Bool => 1u32.to_biguint(),
46            Type::Tuple(inner) => inner.iter().map(Type::size).sum::<BigUint>(),
47            Type::Struct(inner) => inner.iter().map(|(_, t)| t.size()).sum::<BigUint>(),
48            Type::Enum(inner) => {
49                let discriminant_size = (inner.len() as f32).log2().ceil() as u64;
50
51                let members_size = inner
52                    .iter()
53                    .map(|m| m.iter().map(|t| t.size()).sum())
54                    .max()
55                    .unwrap_or(BigUint::zero());
56
57                discriminant_size + members_size
58            }
59            Type::Array { inner, length } => inner.size() * length,
60            Type::Memory { inner, length } => inner.size() * length,
61            Type::Backward(_) => BigUint::zero(),
62            Type::InOut(inner) => inner.size(),
63        }
64    }
65
66    pub fn backward_size(&self) -> BigUint {
67        match self {
68            Type::Backward(inner) => inner.size(),
69            Type::Int(_) | Type::UInt(_) | Type::Bool => BigUint::zero(),
70            Type::Array { inner, length } => inner.backward_size() * length,
71            Type::Enum(inner) => {
72                for v in inner {
73                    for i in v {
74                        if i.backward_size() != BigUint::zero() {
75                            unreachable!("Enums cannot have output wires as payload")
76                        }
77                    }
78                }
79                BigUint::zero()
80            }
81            Type::Memory { inner, .. } => {
82                if inner.backward_size() != BigUint::zero() {
83                    unreachable!("Memory cannot contain output wires")
84                };
85                BigUint::zero()
86            }
87            Type::Tuple(inner) => inner.iter().map(Type::backward_size).sum::<BigUint>(),
88            Type::Struct(inner) => inner
89                .iter()
90                .map(|(_, t)| t.backward_size())
91                .sum::<BigUint>(),
92            Type::InOut(_) => BigUint::zero(),
93        }
94    }
95
96    pub fn assume_enum(&self) -> &Vec<Vec<Type>> {
97        if let Type::Enum(inner) = self {
98            inner
99        } else {
100            panic!("Assumed enum for a type which was not")
101        }
102    }
103
104    pub fn must_use(&self) -> bool {
105        match self {
106            Type::Tuple(unit) if unit.is_empty() => false,
107            _ => true,
108        }
109    }
110}
111
112impl std::fmt::Display for Type {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        match self {
115            Type::Int(val) => write!(f, "int<{}>", val),
116            Type::UInt(val) => write!(f, "uint<{}>", val),
117            Type::Bool => write!(f, "bool"),
118            Type::Tuple(inner) => {
119                let inner = inner
120                    .iter()
121                    .map(|p| format!("{}", p))
122                    .collect::<Vec<_>>()
123                    .join(", ");
124                write!(f, "({})", inner)
125            }
126            Type::Struct(inner) => {
127                let inner = inner
128                    .iter()
129                    .map(|(n, t)| format!("{n}: {t}"))
130                    .collect::<Vec<_>>()
131                    .join(", ");
132                write!(f, "{{{}}}", inner)
133            }
134            Type::Array { inner, length } => {
135                write!(f, "[{}; {}]", inner, length)
136            }
137            Type::Memory { inner, length } => {
138                write!(f, "Memory[{}; {}]", inner, length)
139            }
140            Type::Enum(inner) => {
141                let inner = inner
142                    .iter()
143                    .map(|variant| {
144                        let members = variant
145                            .iter()
146                            .map(|t| format!("{}", t))
147                            .collect::<Vec<_>>()
148                            .join(", ");
149                        format!("option [{}]", members)
150                    })
151                    .collect::<Vec<_>>()
152                    .join(", ");
153
154                write!(f, "enum {}", inner)
155            }
156            Type::Backward(inner) => {
157                write!(f, "&mut ({inner})")
158            }
159            Type::InOut(inner) => {
160                write!(f, "inout<{inner}>")
161            }
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pure_enum_size_is_correct() {
        // 2 variant enum
        assert_eq!(Type::Enum(vec![vec![], vec![]]).size(), 1u32.to_biguint());
    }

    #[test]
    fn enum_with_payload_size_is_correct() {
        // 2 variant enum: 1 discriminant bit + the 5 bit payload of the largest variant
        assert_eq!(
            Type::Enum(vec![vec![Type::Int(5u32.to_biguint())], vec![Type::Bool]]).size(),
            6u32.to_biguint()
        );
    }

    #[test]
    fn single_variant_enum_is_0_bits() {
        assert_eq!(Type::Enum(vec![vec![]]).size(), BigUint::zero());
    }

    #[test]
    fn empty_enum_is_0_bits() {
        assert_eq!(Type::Enum(vec![]).size(), BigUint::zero());
    }

    #[test]
    fn discriminant_rounds_up_for_non_power_of_two_variant_counts() {
        let pure_enum = |variants: usize| Type::Enum(vec![vec![]; variants]);
        assert_eq!(pure_enum(3).size(), 2u32.to_biguint());
        assert_eq!(pure_enum(4).size(), 2u32.to_biguint());
        assert_eq!(pure_enum(5).size(), 3u32.to_biguint());
        assert_eq!(pure_enum(9).size(), 4u32.to_biguint());
    }

    #[test]
    fn array_size_scales_with_length() {
        let array = Type::Array {
            inner: Box::new(Type::int(3)),
            length: 4u32.to_biguint(),
        };
        assert_eq!(array.size(), 12u32.to_biguint());
    }

    #[test]
    fn backward_parts_only_count_towards_backward_size() {
        let ty = Type::Tuple(vec![Type::Bool, Type::backward(Type::uint(4))]);
        assert_eq!(ty.size(), 1u32.to_biguint());
        assert_eq!(ty.backward_size(), 4u32.to_biguint());

        let inout = Type::InOut(Box::new(Type::uint(2)));
        assert_eq!(inout.size(), 2u32.to_biguint());
        assert_eq!(inout.backward_size(), BigUint::zero());
    }

    #[test]
    fn only_unit_is_not_must_use() {
        assert!(!Type::unit().must_use());
        assert!(Type::Bool.must_use());
        assert!(Type::Tuple(vec![Type::Bool]).must_use());
    }

    #[test]
    fn display_formats_every_variant() {
        let nested = Type::Tuple(vec![
            Type::Struct(vec![("a".to_string(), Type::int(8))]),
            Type::Enum(vec![vec![Type::Bool, Type::uint(2)], vec![]]),
            Type::backward(Type::Bool),
        ]);
        assert_eq!(
            nested.to_string(),
            "({a: int<8>}, enum option [bool, uint<2>], option [], &mut (bool))"
        );

        let array = Type::Array {
            inner: Box::new(Type::int(3)),
            length: 4u32.to_biguint(),
        };
        assert_eq!(array.to_string(), "[int<3>; 4]");

        let memory = Type::Memory {
            inner: Box::new(Type::Bool),
            length: 2u32.to_biguint(),
        };
        assert_eq!(memory.to_string(), "Memory[bool; 2]");

        assert_eq!(Type::InOut(Box::new(Type::uint(1))).to_string(), "inout<uint<1>>");
        assert_eq!(Type::unit().to_string(), "()");
    }
}