1use std::borrow::BorrowMut;
2
3use crate::{ConstGenericWithId, Pattern, TypeExpression, TypeParam};
4
5use super::{Block, NameID};
6use num::{BigInt, BigUint};
7use serde::{Deserialize, Serialize};
8use spade_common::{
9 id_tracker::ExprID,
10 location_info::{Loc, WithLocation},
11 name::{Identifier, Path},
12 num_ext::InfallibleToBigInt,
13};
14
/// A binary (infix) operator used in [`ExprKind::BinaryOperator`].
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`; with index-based serde
/// formats the variant order is part of the encoding, so avoid reordering variants.
#[derive(Clone, Copy, PartialEq, Debug, Serialize, Deserialize)]
pub enum BinaryOperator {
    // Arithmetic
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    // Comparison
    Eq,
    NotEq,
    Gt,
    Lt,
    Ge,
    Le,
    // Shifts
    LeftShift,
    RightShift,
    ArithmeticRightShift,
    // Logical (boolean) operators
    LogicalAnd,
    LogicalOr,
    LogicalXor,
    // Bitwise operators
    BitwiseOr,
    BitwiseAnd,
    BitwiseXor,
}
38
39impl std::fmt::Display for BinaryOperator {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 BinaryOperator::Add => write!(f, "+"),
43 BinaryOperator::Sub => write!(f, "-"),
44 BinaryOperator::Mul => write!(f, "*"),
45 BinaryOperator::Div => write!(f, "/"),
46 BinaryOperator::Mod => write!(f, "%"),
47 BinaryOperator::Eq => write!(f, "=="),
48 BinaryOperator::NotEq => write!(f, "!="),
49 BinaryOperator::Gt => write!(f, ">"),
50 BinaryOperator::Lt => write!(f, "<"),
51 BinaryOperator::Ge => write!(f, ">="),
52 BinaryOperator::Le => write!(f, "<="),
53 BinaryOperator::LeftShift => write!(f, ">>"),
54 BinaryOperator::RightShift => write!(f, "<<"),
55 BinaryOperator::ArithmeticRightShift => write!(f, ">>>"),
56 BinaryOperator::LogicalAnd => write!(f, "&&"),
57 BinaryOperator::LogicalOr => write!(f, "||"),
58 BinaryOperator::LogicalXor => write!(f, "^^"),
59 BinaryOperator::BitwiseOr => write!(f, "|"),
60 BinaryOperator::BitwiseAnd => write!(f, "&"),
61 BinaryOperator::BitwiseXor => write!(f, "^"),
62 }
63 }
64}
// Opt in to the `WithLocation` helpers; operators are carried as `Loc<BinaryOperator>`
// inside `ExprKind::BinaryOperator`.
impl WithLocation for BinaryOperator {}
66
/// A unary operator used in [`ExprKind::UnaryOperator`]. The `Display` impl gives the
/// symbol for each variant (`-`, `!`, `~`, `*`, `&`).
#[derive(Clone, Copy, PartialEq, Debug, Serialize, Deserialize)]
pub enum UnaryOperator {
    Sub,
    Not,
    BitwiseNot,
    Dereference,
    Reference,
}

// Operators are carried as `Loc<UnaryOperator>` inside `ExprKind::UnaryOperator`.
impl WithLocation for UnaryOperator {}
77
78impl std::fmt::Display for UnaryOperator {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 UnaryOperator::Sub => write!(f, "-"),
82 UnaryOperator::Not => write!(f, "!"),
83 UnaryOperator::BitwiseNot => write!(f, "~"),
84 UnaryOperator::Dereference => write!(f, "*"),
85 UnaryOperator::Reference => write!(f, "&"),
86 }
87 }
88}
89
/// One entry of an [`ArgumentList::Named`] list: the target parameter name and the value.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum NamedArgument<T> {
    /// An argument written with an explicit name and a separate value.
    Full(Loc<Identifier>, Loc<T>),
    /// NOTE(review): presumably the shorthand form where the name itself doubles as the
    /// value (e.g. `a` meaning `a: a`) — confirm against the parser.
    Short(Loc<Identifier>, Loc<T>),
}
impl<T> WithLocation for NamedArgument<T> {}
100
/// Records which syntax was used to pass an [`Argument`].
///
/// NOTE(review): `Named`/`ShortNamed` presumably correspond to
/// `NamedArgument::Full`/`NamedArgument::Short` — confirm.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum ArgumentKind {
    Positional,
    Named,
    ShortNamed,
}
109
/// The arguments of a call, either all named or all positional.
///
/// `T` is the argument payload, e.g. `Expression` for call arguments or
/// `TypeExpression` for turbofish arguments.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum ArgumentList<T> {
    Named(Vec<NamedArgument<T>>),
    Positional(Vec<Loc<T>>),
}
115
116impl<T> ArgumentList<T> {
117 pub fn expressions(&self) -> Vec<&Loc<T>> {
118 match self {
119 ArgumentList::Named(n) => n
120 .iter()
121 .map(|arg| match &arg {
122 NamedArgument::Full(_, expr) => expr,
123 NamedArgument::Short(_, expr) => expr,
124 })
125 .collect(),
126 ArgumentList::Positional(arg) => arg.iter().collect(),
127 }
128 }
129 pub fn expressions_mut(&mut self) -> Vec<&mut Loc<T>> {
130 match self {
131 ArgumentList::Named(n) => n
132 .iter_mut()
133 .map(|arg| match arg {
134 NamedArgument::Full(_, expr) => expr,
135 NamedArgument::Short(_, expr) => expr,
136 })
137 .collect(),
138 ArgumentList::Positional(arg) => arg.iter_mut().collect(),
139 }
140 }
141}
142impl<T> WithLocation for ArgumentList<T> {}
143
/// A single call argument: the parameter it targets, the value passed, and the
/// syntax used to pass it.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct Argument<T> {
    /// Name of the parameter this argument is bound to.
    pub target: Loc<Identifier>,
    /// The argument value.
    pub value: Loc<T>,
    /// How the argument was written (positional, named, or short named).
    pub kind: ArgumentKind,
}
150
/// The kind of unit being called by an `ExprKind::Call` or `ExprKind::MethodCall`.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum CallKind {
    Function,
    /// NOTE(review): the `Loc<()>` presumably marks the instantiation keyword — confirm.
    Entity(Loc<()>),
    Pipeline {
        /// NOTE(review): presumably the location of the instantiation keyword — confirm.
        inst_loc: Loc<()>,
        /// Pipeline depth, expressed as a type expression.
        depth: Loc<TypeExpression>,
        /// Expression id associated with `depth`.
        depth_typeexpr_id: ExprID,
    },
}
impl WithLocation for CallKind {}
165
/// Value of a single-bit literal.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum BitLiteral {
    Low,
    High,
    /// NOTE(review): presumably a high-impedance (`Z`) value — confirm.
    HighImp,
}
172
/// Size and signedness annotation attached to an integer literal.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum IntLiteralKind {
    /// No explicit size; `ExprKind::int_literal` produces this kind.
    Unsized,
    /// NOTE(review): the `BigUint` is presumably the explicit bit width — confirm.
    Signed(BigUint),
    /// NOTE(review): the `BigUint` is presumably the explicit bit width — confirm.
    Unsigned(BigUint),
}
179
/// How a pipeline reference (`ExprKind::PipelineRef`) selects its stage.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum PipelineRefKind {
    /// Stage selected by name.
    Absolute(Loc<NameID>),
    /// Stage selected by an offset given as a type expression.
    /// NOTE(review): presumably relative to the current stage — confirm.
    Relative(Loc<TypeExpression>),
}
impl WithLocation for PipelineRefKind {}
186
/// A generic parameter captured by a lambda (see `ExprKind::LambdaDef`), recorded
/// under both of its names.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct CapturedLambdaParam {
    /// The name the parameter has inside the lambda.
    pub name_in_lambda: NameID,
    /// The name the parameter has in the enclosing body.
    pub name_in_body: Loc<NameID>,
}
192
/// The different forms an [`Expression`] can take.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum ExprKind {
    /// NOTE(review): presumably a placeholder left where an earlier error prevented
    /// building a real expression — confirm. `Expression::assume_block` maps it to `Err(())`.
    Error,
    Identifier(NameID),
    /// Integer literal value together with its size/signedness annotation.
    IntLiteral(BigInt, IntLiteralKind),
    BoolLiteral(bool),
    BitLiteral(BitLiteral),
    /// A type-level integer (generic parameter) used as a value.
    TypeLevelInteger(NameID),
    CreatePorts,
    TupleLiteral(Vec<Loc<Expression>>),
    ArrayLiteral(Vec<Loc<Expression>>),
    /// An array built from one element expression and a const-generic count.
    /// NOTE(review): presumably the `[value; count]` syntax — confirm.
    ArrayShorthandLiteral(Box<Loc<Expression>>, Loc<ConstGenericWithId>),
    /// `target[index]`.
    Index(Box<Loc<Expression>>, Box<Loc<Expression>>),
    /// Indexes `target` with a const-generic `start`/`end` pair.
    RangeIndex {
        target: Box<Loc<Expression>>,
        start: Loc<ConstGenericWithId>,
        end: Loc<ConstGenericWithId>,
    },
    /// Access to a tuple element by constant position.
    TupleIndex(Box<Loc<Expression>>, Loc<u128>),
    FieldAccess(Box<Loc<Expression>>, Loc<Identifier>),
    MethodCall {
        target: Box<Loc<Expression>>,
        name: Loc<Identifier>,
        args: Loc<ArgumentList<Expression>>,
        call_kind: CallKind,
        /// Explicitly supplied generic arguments, if any.
        turbofish: Option<Loc<ArgumentList<TypeExpression>>>,
    },
    Call {
        kind: CallKind,
        callee: Loc<NameID>,
        args: Loc<ArgumentList<Expression>>,
        /// Explicitly supplied generic arguments, if any.
        turbofish: Option<Loc<ArgumentList<TypeExpression>>>,
    },
    /// `lhs <op> rhs`.
    BinaryOperator(
        Box<Loc<Expression>>,
        Loc<BinaryOperator>,
        Box<Loc<Expression>>,
    ),
    UnaryOperator(Loc<UnaryOperator>, Box<Loc<Expression>>),
    /// The matched value and its `(pattern, result)` arms.
    Match(Box<Loc<Expression>>, Vec<(Loc<Pattern>, Loc<Expression>)>),
    Block(Box<Block>),
    /// Condition, then-branch, else-branch.
    If(
        Box<Loc<Expression>>,
        Box<Loc<Expression>>,
        Box<Loc<Expression>>,
    ),
    /// Like `If`, but the condition is a const-generic value.
    TypeLevelIf(
        Loc<ConstGenericWithId>,
        Box<Loc<Expression>>,
        Box<Loc<Expression>>,
    ),
    /// A reference to `name` at the pipeline stage selected by `stage`.
    PipelineRef {
        stage: Loc<PipelineRefKind>,
        name: Loc<NameID>,
        /// NOTE(review): presumably true when this reference introduces `name` rather
        /// than referring to an existing declaration — confirm.
        declares_name: bool,
        /// Expression id associated with the stage type expression.
        depth_typeexpr_id: ExprID,
    },
    /// A lambda definition.
    LambdaDef {
        /// NOTE(review): presumably the type generated for this lambda — confirm.
        lambda_type: NameID,
        lambda_type_params: Vec<Loc<TypeParam>>,
        /// Generic parameters captured from the enclosing scope.
        captured_generic_params: Vec<CapturedLambdaParam>,
        /// NOTE(review): presumably the unit generated for the lambda body — confirm.
        lambda_unit: NameID,
        arguments: Vec<Loc<Pattern>>,
        body: Box<Loc<Expression>>,
    },
    StageValid,
    StageReady,
    /// Carries a message string.
    /// NOTE(review): presumably a diagnostic for a branch that must never be taken — confirm.
    StaticUnreachable(Loc<String>),
    Null,
}
impl WithLocation for ExprKind {}
271
272impl ExprKind {
273 pub fn with_id(self, id: ExprID) -> Expression {
274 Expression { kind: self, id }
275 }
276
277 pub fn idless(self) -> Expression {
279 Expression {
280 kind: self,
281 id: ExprID(0),
282 }
283 }
284
285 pub fn int_literal(val: i32) -> Self {
286 Self::IntLiteral(val.to_bigint(), IntLiteralKind::Unsized)
287 }
288}
289
/// An expression: its [`ExprKind`] plus an [`ExprID`].
///
/// Equality is implemented manually and compares only `kind`; `id` is ignored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Expression {
    pub kind: ExprKind,
    /// Identifier of this expression; `ExprKind::idless` sets it to `ExprID(0)`.
    pub id: ExprID,
}
impl WithLocation for Expression {}
297
298impl Expression {
299 pub fn ident(expr_id: ExprID, name_id: u64, name: &str) -> Expression {
302 ExprKind::Identifier(NameID(name_id, Path::from_strs(&[name]))).with_id(expr_id)
303 }
304
305 pub fn assume_block(&self) -> std::result::Result<&Block, ()> {
307 if let ExprKind::Block(ref block) = self.kind {
308 Ok(block)
309 } else if let ExprKind::Error = self.kind {
310 Err(())
311 } else {
312 panic!("Expression is not a block")
313 }
314 }
315
316 pub fn assume_block_mut(&mut self) -> &mut Block {
318 if let ExprKind::Block(block) = &mut self.kind {
319 block.borrow_mut()
320 } else {
321 panic!("Expression is not a block")
322 }
323 }
324}
325
326impl PartialEq for Expression {
327 fn eq(&self, other: &Self) -> bool {
328 self.kind == other.kind
329 }
330}
331
/// Extension methods for located expressions (`Loc<Expression>`).
pub trait LocExprExt {
    /// Returns `None` if the expression can be evaluated without runtime values, and
    /// otherwise `Some` (sub-)expression that requires runtime evaluation.
    fn runtime_requirement_witness(&self) -> Option<Loc<Expression>>;
}
335
impl LocExprExt for Loc<Expression> {
    /// Finds a witness that this expression needs runtime evaluation.
    ///
    /// Returns `None` when the whole expression is built only from integer, bool, and
    /// type-level-integer literals (plus `StaticUnreachable`/`Null`), combined with
    /// `+`, binary `-`, unary `-`, tuple/array literals, array shorthand, indexing,
    /// tuple indexing, or field access.
    ///
    /// Otherwise it returns `Some` of the first offending (sub-)expression. Sub-expressions
    /// are searched left to right before the enclosing node itself is returned.
    fn runtime_requirement_witness(&self) -> Option<Loc<Expression>> {
        match &self.kind {
            ExprKind::Error => Some(self.clone()),
            ExprKind::Identifier(_) => Some(self.clone()),
            ExprKind::TypeLevelInteger(_) => None,
            ExprKind::IntLiteral(_, _) => None,
            ExprKind::BoolLiteral(_) => None,
            ExprKind::BitLiteral(_) => Some(self.clone()),
            ExprKind::TupleLiteral(inner) => {
                inner.iter().find_map(Self::runtime_requirement_witness)
            }
            ExprKind::ArrayLiteral(inner) => {
                inner.iter().find_map(Self::runtime_requirement_witness)
            }
            ExprKind::ArrayShorthandLiteral(inner, _) => inner.runtime_requirement_witness(),
            ExprKind::CreatePorts => Some(self.clone()),
            ExprKind::Index(l, r) => l
                .runtime_requirement_witness()
                .or_else(|| r.runtime_requirement_witness()),
            ExprKind::RangeIndex { .. } => Some(self.clone()),
            ExprKind::TupleIndex(l, _) => l.runtime_requirement_witness(),
            ExprKind::FieldAccess(l, _) => l.runtime_requirement_witness(),
            // Calls are always treated as requiring runtime evaluation; their arguments
            // are not inspected.
            ExprKind::MethodCall { .. } | ExprKind::Call { .. } => Some(self.clone()),
            ExprKind::BinaryOperator(l, operator, r) => {
                if let Some(witness) = l
                    .runtime_requirement_witness()
                    .or_else(|| r.runtime_requirement_witness())
                {
                    Some(witness)
                } else {
                    // Operands are compile-time evaluable; only `+` and `-` are accepted.
                    // The list is deliberately exhaustive so new operators must be classified.
                    match &operator.inner {
                        BinaryOperator::Add => None,
                        BinaryOperator::Sub => None,
                        BinaryOperator::Mul
                        | BinaryOperator::Div
                        | BinaryOperator::Mod
                        | BinaryOperator::Eq
                        | BinaryOperator::NotEq
                        | BinaryOperator::Gt
                        | BinaryOperator::Lt
                        | BinaryOperator::Ge
                        | BinaryOperator::Le
                        | BinaryOperator::LeftShift
                        | BinaryOperator::RightShift
                        | BinaryOperator::ArithmeticRightShift
                        | BinaryOperator::LogicalAnd
                        | BinaryOperator::LogicalOr
                        | BinaryOperator::LogicalXor
                        | BinaryOperator::BitwiseOr
                        | BinaryOperator::BitwiseAnd
                        | BinaryOperator::BitwiseXor => Some(self.clone()),
                    }
                }
            }
            ExprKind::UnaryOperator(op, operand) => {
                if let Some(witness) = operand.runtime_requirement_witness() {
                    Some(witness)
                } else {
                    // Only negation is accepted on a compile-time evaluable operand.
                    match op.inner {
                        UnaryOperator::Sub => None,
                        UnaryOperator::Not
                        | UnaryOperator::BitwiseNot
                        | UnaryOperator::Dereference
                        | UnaryOperator::Reference => Some(self.clone()),
                    }
                }
            }
            ExprKind::Match(_, _) => Some(self.clone()),
            ExprKind::Block(_) => Some(self.clone()),
            ExprKind::If(_, _, _) => Some(self.clone()),
            ExprKind::TypeLevelIf(_, _, _) => Some(self.clone()),
            ExprKind::PipelineRef { .. } => Some(self.clone()),
            ExprKind::StageReady => Some(self.clone()),
            ExprKind::StageValid => Some(self.clone()),
            ExprKind::LambdaDef { .. } => Some(self.clone()),
            ExprKind::StaticUnreachable(_) => None,
            ExprKind::Null => None,
        }
    }
}
423}