// spade_mir/types.rs
1use num::{BigUint, Zero};
2use spade_common::num_ext::InfallibleToBigUint;
3
/// A type as handled in this crate: carries enough structure to compute how many
/// bits a value occupies (`Type::size`) and how many bits flow in the opposite
/// direction (`Type::backward_size`).
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum Type {
    /// `int<N>`: the payload is the width `N` in bits.
    Int(BigUint),
    /// `uint<N>`: the payload is the width `N` in bits.
    UInt(BigUint),
    /// A single bit.
    Bool,
    /// Anonymous product type; the empty tuple is the unit type.
    Tuple(Vec<Type>),
    /// Product type with named fields, stored as `(field name, field type)` pairs.
    Struct(Vec<(String, Type)>),
    /// `length` consecutive elements of type `inner`.
    Array {
        inner: Box<Type>,
        length: BigUint,
    },
    /// A memory of `length` elements of type `inner`. Its inner type may not
    /// contain backward wires (`backward_size` treats that as unreachable).
    Memory {
        inner: Box<Type>,
        length: BigUint,
    },
    /// Sum type: one entry per variant, each entry listing that variant's payload
    /// types. Payloads may not contain backward wires.
    Enum(Vec<Vec<Type>>),
    /// A wire in the opposite direction. It contributes nothing to `size` and the
    /// inner type's size to `backward_size`; it is displayed as `&mut (T)`.
    Backward(Box<Type>),
    /// An `inout<T>` wire. It counts towards `size` like its inner type but not
    /// towards `backward_size`.
    InOut(Box<Type>),
}
26
27impl Type {
28 pub fn int(val: u32) -> Self {
29 Self::Int(val.to_biguint())
30 }
31 pub fn uint(val: u32) -> Self {
32 Self::UInt(val.to_biguint())
33 }
34 pub fn backward(inner: Type) -> Self {
35 Self::Backward(Box::new(inner))
36 }
37 pub fn unit() -> Self {
38 Self::Tuple(Vec::new())
39 }
40
41 pub fn size(&self) -> BigUint {
42 match self {
43 Type::Int(len) => len.clone(),
44 Type::UInt(len) => len.clone(),
45 Type::Bool => 1u32.to_biguint(),
46 Type::Tuple(inner) => inner.iter().map(Type::size).sum::<BigUint>(),
47 Type::Struct(inner) => inner.iter().map(|(_, t)| t.size()).sum::<BigUint>(),
48 Type::Enum(inner) => {
49 let discriminant_size = (inner.len() as f32).log2().ceil() as u64;
50
51 let members_size = inner
52 .iter()
53 .map(|m| m.iter().map(|t| t.size()).sum())
54 .max()
55 .unwrap_or(BigUint::zero());
56
57 discriminant_size + members_size
58 }
59 Type::Array { inner, length } => inner.size() * length,
60 Type::Memory { inner, length } => inner.size() * length,
61 Type::Backward(_) => BigUint::zero(),
62 Type::InOut(inner) => inner.size(),
63 }
64 }
65
66 pub fn backward_size(&self) -> BigUint {
67 match self {
68 Type::Backward(inner) => inner.size(),
69 Type::Int(_) | Type::UInt(_) | Type::Bool => BigUint::zero(),
70 Type::Array { inner, length } => inner.backward_size() * length,
71 Type::Enum(inner) => {
72 for v in inner {
73 for i in v {
74 if i.backward_size() != BigUint::zero() {
75 unreachable!("Enums cannot have output wires as payload")
76 }
77 }
78 }
79 BigUint::zero()
80 }
81 Type::Memory { inner, .. } => {
82 if inner.backward_size() != BigUint::zero() {
83 unreachable!("Memory cannot contain output wires")
84 };
85 BigUint::zero()
86 }
87 Type::Tuple(inner) => inner.iter().map(Type::backward_size).sum::<BigUint>(),
88 Type::Struct(inner) => inner
89 .iter()
90 .map(|(_, t)| t.backward_size())
91 .sum::<BigUint>(),
92 Type::InOut(_) => BigUint::zero(),
93 }
94 }
95
96 pub fn assume_enum(&self) -> &Vec<Vec<Type>> {
97 if let Type::Enum(inner) = self {
98 inner
99 } else {
100 panic!("Assumed enum for a type which was not")
101 }
102 }
103
104 pub fn must_use(&self) -> bool {
105 match self {
106 Type::Tuple(unit) if unit.is_empty() => false,
107 _ => true,
108 }
109 }
110}
111
112impl std::fmt::Display for Type {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 match self {
115 Type::Int(val) => write!(f, "int<{}>", val),
116 Type::UInt(val) => write!(f, "uint<{}>", val),
117 Type::Bool => write!(f, "bool"),
118 Type::Tuple(inner) => {
119 let inner = inner
120 .iter()
121 .map(|p| format!("{}", p))
122 .collect::<Vec<_>>()
123 .join(", ");
124 write!(f, "({})", inner)
125 }
126 Type::Struct(inner) => {
127 let inner = inner
128 .iter()
129 .map(|(n, t)| format!("{n}: {t}"))
130 .collect::<Vec<_>>()
131 .join(", ");
132 write!(f, "{{{}}}", inner)
133 }
134 Type::Array { inner, length } => {
135 write!(f, "[{}; {}]", inner, length)
136 }
137 Type::Memory { inner, length } => {
138 write!(f, "Memory[{}; {}]", inner, length)
139 }
140 Type::Enum(inner) => {
141 let inner = inner
142 .iter()
143 .map(|variant| {
144 let members = variant
145 .iter()
146 .map(|t| format!("{}", t))
147 .collect::<Vec<_>>()
148 .join(", ");
149 format!("option [{}]", members)
150 })
151 .collect::<Vec<_>>()
152 .join(", ");
153
154 write!(f, "enum {}", inner)
155 }
156 Type::Backward(inner) => {
157 write!(f, "&mut ({inner})")
158 }
159 Type::InOut(inner) => {
160 write!(f, "inout<{inner}>")
161 }
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pure_enum_size_is_correct() {
        assert_eq!(Type::Enum(vec![vec![], vec![]]).size(), 1u32.to_biguint());
    }

    #[test]
    fn enum_with_payload_size_is_correct() {
        assert_eq!(
            Type::Enum(vec![vec![Type::Int(5u32.to_biguint())], vec![Type::Bool]]).size(),
            6u32.to_biguint()
        );
    }

    #[test]
    fn single_variant_enum_is_0_bits() {
        assert_eq!(Type::Enum(vec![vec![]]).size(), BigUint::zero());
    }

    #[test]
    fn empty_enum_is_0_bits() {
        assert_eq!(Type::Enum(vec![]).size(), BigUint::zero());
    }

    #[test]
    fn non_power_of_two_enum_discriminant_rounds_up() {
        // 3 variants need 2 discriminant bits, 5 variants need 3.
        assert_eq!(Type::Enum(vec![vec![]; 3]).size(), 2u32.to_biguint());
        assert_eq!(Type::Enum(vec![vec![]; 5]).size(), 3u32.to_biguint());
    }

    #[test]
    fn power_of_two_enum_discriminant_is_exact() {
        assert_eq!(Type::Enum(vec![vec![]; 4]).size(), 2u32.to_biguint());
        assert_eq!(Type::Enum(vec![vec![]; 8]).size(), 3u32.to_biguint());
    }

    #[test]
    fn backward_wires_only_count_towards_backward_size() {
        let t = Type::Tuple(vec![Type::Bool, Type::backward(Type::uint(8))]);
        assert_eq!(t.size(), 1u32.to_biguint());
        assert_eq!(t.backward_size(), 8u32.to_biguint());
    }

    #[test]
    fn array_backward_size_scales_with_length() {
        let t = Type::Array {
            inner: Box::new(Type::backward(Type::Bool)),
            length: 4u32.to_biguint(),
        };
        assert_eq!(t.size(), BigUint::zero());
        assert_eq!(t.backward_size(), 4u32.to_biguint());
    }

    #[test]
    fn only_unit_is_not_must_use() {
        assert!(!Type::unit().must_use());
        assert!(Type::Bool.must_use());
    }

    #[test]
    fn display_renders_nested_types() {
        let t = Type::Tuple(vec![
            Type::int(8),
            Type::Array {
                inner: Box::new(Type::Bool),
                length: 2u32.to_biguint(),
            },
        ]);
        assert_eq!(t.to_string(), "(int<8>, [bool; 2])");
        assert_eq!(Type::unit().to_string(), "()");
    }
}