spade_hir/
expression.rs

1use std::borrow::BorrowMut;
2
3use crate::{ConstGenericWithId, Pattern, TypeExpression, TypeParam};
4
5use super::{Block, NameID};
6use num::{BigInt, BigUint};
7use serde::{Deserialize, Serialize};
8use spade_common::{
9    id_tracker::ExprID,
10    location_info::{Loc, WithLocation},
11    name::{Identifier, Path},
12    num_ext::InfallibleToBigInt,
13};
14
/// A binary (infix) operator appearing in a HIR expression.
///
/// NOTE: this type derives `Serialize`/`Deserialize`; renaming or reordering variants may
/// change its serialized representation.
#[derive(Clone, Copy, PartialEq, Debug, Serialize, Deserialize)]
pub enum BinaryOperator {
    // Arithmetic
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    // Comparisons
    Eq,
    NotEq,
    Gt,
    Lt,
    Ge,
    Le,
    // Shifts
    LeftShift,
    RightShift,
    /// Rendered as `>>>`. NOTE(review): presumably the sign-preserving variant of
    /// `RightShift` — confirm.
    ArithmeticRightShift,
    // Logical operators, rendered as `&&`, `||` and `^^`
    LogicalAnd,
    LogicalOr,
    LogicalXor,
    // Bitwise operators, rendered as `|`, `&` and `^`
    BitwiseOr,
    BitwiseAnd,
    BitwiseXor,
}
38
39impl std::fmt::Display for BinaryOperator {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            BinaryOperator::Add => write!(f, "+"),
43            BinaryOperator::Sub => write!(f, "-"),
44            BinaryOperator::Mul => write!(f, "*"),
45            BinaryOperator::Div => write!(f, "/"),
46            BinaryOperator::Mod => write!(f, "%"),
47            BinaryOperator::Eq => write!(f, "=="),
48            BinaryOperator::NotEq => write!(f, "!="),
49            BinaryOperator::Gt => write!(f, ">"),
50            BinaryOperator::Lt => write!(f, "<"),
51            BinaryOperator::Ge => write!(f, ">="),
52            BinaryOperator::Le => write!(f, "<="),
53            BinaryOperator::LeftShift => write!(f, ">>"),
54            BinaryOperator::RightShift => write!(f, "<<"),
55            BinaryOperator::ArithmeticRightShift => write!(f, ">>>"),
56            BinaryOperator::LogicalAnd => write!(f, "&&"),
57            BinaryOperator::LogicalOr => write!(f, "||"),
58            BinaryOperator::LogicalXor => write!(f, "^^"),
59            BinaryOperator::BitwiseOr => write!(f, "|"),
60            BinaryOperator::BitwiseAnd => write!(f, "&"),
61            BinaryOperator::BitwiseXor => write!(f, "^"),
62        }
63    }
64}
65impl WithLocation for BinaryOperator {}
66
/// A unary operator appearing in a HIR expression.
///
/// NOTE: this type derives `Serialize`/`Deserialize`; renaming or reordering variants may
/// change its serialized representation.
#[derive(Clone, Copy, PartialEq, Debug, Serialize, Deserialize)]
pub enum UnaryOperator {
    /// Negation, rendered as `-`.
    Sub,
    /// Rendered as `!`.
    Not,
    /// Rendered as `~`.
    BitwiseNot,
    /// Rendered as `*`.
    Dereference,
    /// Rendered as `&`.
    Reference,
}

impl WithLocation for UnaryOperator {}
77
78impl std::fmt::Display for UnaryOperator {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        match self {
81            UnaryOperator::Sub => write!(f, "-"),
82            UnaryOperator::Not => write!(f, "!"),
83            UnaryOperator::BitwiseNot => write!(f, "~"),
84            UnaryOperator::Dereference => write!(f, "*"),
85            UnaryOperator::Reference => write!(f, "&"),
86        }
87    }
88}
89
/// A named argument binding. Named arguments are used both for type parameters in
/// turbofishes and in argument lists. `T` is the right-hand side of a binding, e.g. an
/// expression in an argument list.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum NamedArgument<T> {
    /// Binds the argument named by the LHS identifier to the RHS value
    Full(Loc<Identifier>, Loc<T>),
    /// Binds a local variable to an argument with the same name
    Short(Loc<Identifier>, Loc<T>),
}
impl<T> WithLocation for NamedArgument<T> {}
100
/// Specifies how an argument is bound. Mainly used for error reporting without
/// code duplication.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum ArgumentKind {
    /// Bound by its position in the argument list.
    Positional,
    /// Bound by an explicit name. NOTE(review): presumably corresponds to
    /// `NamedArgument::Full` — confirm.
    Named,
    /// Bound by name shorthand. NOTE(review): presumably corresponds to
    /// `NamedArgument::Short` — confirm.
    ShortNamed,
}
109
/// A list of arguments where either every argument is passed by name or every argument is
/// passed by position.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum ArgumentList<T> {
    Named(Vec<NamedArgument<T>>),
    Positional(Vec<Loc<T>>),
}
115
116impl<T> ArgumentList<T> {
117    pub fn expressions(&self) -> Vec<&Loc<T>> {
118        match self {
119            ArgumentList::Named(n) => n
120                .iter()
121                .map(|arg| match &arg {
122                    NamedArgument::Full(_, expr) => expr,
123                    NamedArgument::Short(_, expr) => expr,
124                })
125                .collect(),
126            ArgumentList::Positional(arg) => arg.iter().collect(),
127        }
128    }
129    pub fn expressions_mut(&mut self) -> Vec<&mut Loc<T>> {
130        match self {
131            ArgumentList::Named(n) => n
132                .iter_mut()
133                .map(|arg| match arg {
134                    NamedArgument::Full(_, expr) => expr,
135                    NamedArgument::Short(_, expr) => expr,
136                })
137                .collect(),
138            ArgumentList::Positional(arg) => arg.iter_mut().collect(),
139        }
140    }
141}
142impl<T> WithLocation for ArgumentList<T> {}
143
/// A single argument, paired with the parameter name it binds to and how it was bound.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct Argument<T> {
    /// The name of the parameter this argument is bound to.
    pub target: Loc<Identifier>,
    /// The argument value.
    pub value: Loc<T>,
    /// Whether the argument was passed positionally, by name, or by short name.
    pub kind: ArgumentKind,
}
150
// FIXME: Migrate entity, pipeline and fn instantiation to this
/// How the unit targeted by a call (`ExprKind::Call` / `ExprKind::MethodCall`) is invoked.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum CallKind {
    /// A plain function call.
    Function,
    /// An entity instantiation. NOTE(review): the `Loc<()>` presumably marks the
    /// instantiation keyword — confirm.
    Entity(Loc<()>),
    /// A pipeline instantiation.
    Pipeline {
        /// Location of the instantiation (presumably the `inst` keyword — confirm).
        inst_loc: Loc<()>,
        /// The pipeline depth, as a type expression.
        depth: Loc<TypeExpression>,
        /// An expression ID for which the type inferer will infer the depth of the instantiated
        /// pipeline, i.e. inst(<this>)
        depth_typeexpr_id: ExprID,
    },
}
impl WithLocation for CallKind {}
165
/// The value of a bit literal.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum BitLiteral {
    Low,
    High,
    /// NOTE(review): presumably high impedance (`Z`) — confirm.
    HighImp,
}
172
/// The size/signedness annotation attached to an integer literal.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum IntLiteralKind {
    /// No explicit size annotation (this is what `ExprKind::int_literal` produces).
    Unsized,
    /// Signed literal. NOTE(review): the payload is presumably the bit width from a literal
    /// suffix — confirm.
    Signed(BigUint),
    /// Unsigned literal. NOTE(review): the payload is presumably the bit width — confirm.
    Unsigned(BigUint),
}
179
/// How a pipeline reference (`ExprKind::PipelineRef`) identifies the stage it refers to.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum PipelineRefKind {
    /// The stage is identified by a `NameID` (presumably a stage label — confirm).
    Absolute(Loc<NameID>),
    /// The stage is given as a type expression. NOTE(review): presumably an offset relative
    /// to the current stage, judging by the name — confirm.
    Relative(Loc<TypeExpression>),
}
impl WithLocation for PipelineRefKind {}
186
/// A generic parameter captured by a lambda (see `ExprKind::LambdaDef::captured_generic_params`),
/// pairing the name it has inside the lambda with the name it has in the body.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct CapturedLambdaParam {
    pub name_in_lambda: NameID,
    pub name_in_body: Loc<NameID>,
}
192
/// The different kinds of HIR expression. An `ExprKind` is paired with an `ExprID` in
/// [`Expression`].
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum ExprKind {
    /// An error node. `Expression::assume_block` returns `Err(())` for it, and
    /// `runtime_requirement_witness` reports it as requiring runtime.
    Error,
    Identifier(NameID),
    /// An integer literal with its size/signedness annotation, see [`IntLiteralKind`].
    IntLiteral(BigInt, IntLiteralKind),
    BoolLiteral(bool),
    BitLiteral(BitLiteral),
    /// A type-level integer used as a value. NOTE(review): the `NameID` presumably names a
    /// generic parameter — confirm.
    TypeLevelInteger(NameID),
    CreatePorts,
    TupleLiteral(Vec<Loc<Expression>>),
    ArrayLiteral(Vec<Loc<Expression>>),
    /// An array literal made of an element expression and a const-generic count
    /// (presumably `[value; count]` — confirm).
    ArrayShorthandLiteral(Box<Loc<Expression>>, Loc<ConstGenericWithId>),
    /// Indexing; presumably (target, index) — confirm.
    Index(Box<Loc<Expression>>, Box<Loc<Expression>>),
    /// Indexing with a range whose bounds are const generics.
    RangeIndex {
        target: Box<Loc<Expression>>,
        start: Loc<ConstGenericWithId>,
        end: Loc<ConstGenericWithId>,
    },
    TupleIndex(Box<Loc<Expression>>, Loc<u128>),
    FieldAccess(Box<Loc<Expression>>, Loc<Identifier>),
    MethodCall {
        target: Box<Loc<Expression>>,
        name: Loc<Identifier>,
        args: Loc<ArgumentList<Expression>>,
        call_kind: CallKind,
        turbofish: Option<Loc<ArgumentList<TypeExpression>>>,
    },
    /// A call to the unit named by `callee`; `kind` distinguishes function calls from
    /// entity and pipeline instantiations.
    Call {
        kind: CallKind,
        callee: Loc<NameID>,
        args: Loc<ArgumentList<Expression>>,
        turbofish: Option<Loc<ArgumentList<TypeExpression>>>,
    },
    BinaryOperator(
        Box<Loc<Expression>>,
        Loc<BinaryOperator>,
        Box<Loc<Expression>>,
    ),
    UnaryOperator(Loc<UnaryOperator>, Box<Loc<Expression>>),
    /// A match on the first expression, with a list of (pattern, result) arms.
    Match(Box<Loc<Expression>>, Vec<(Loc<Pattern>, Loc<Expression>)>),
    Block(Box<Block>),
    /// NOTE(review): fields are presumably (condition, then-branch, else-branch) — confirm.
    If(
        Box<Loc<Expression>>,
        Box<Loc<Expression>>,
        Box<Loc<Expression>>,
    ),
    /// An `if` whose condition is a const generic rather than a runtime expression.
    TypeLevelIf(
        // FIXME: Having a random u64 is not great, let's make TypeExpressions always have associated ids
        // NOTE(review): the condition field below is a `Loc<ConstGenericWithId>`, not a raw
        // u64, so the FIXME above may be stale.
        Loc<ConstGenericWithId>,
        Box<Loc<Expression>>,
        Box<Loc<Expression>>,
    ),
    PipelineRef {
        stage: Loc<PipelineRefKind>,
        name: Loc<NameID>,
        declares_name: bool,
        /// An expression ID which after typeinference will contain the absolute depth
        /// of this referenced value
        depth_typeexpr_id: ExprID,
    },
    LambdaDef {
        /// The type that this lambda definition creates
        lambda_type: NameID,
        lambda_type_params: Vec<Loc<TypeParam>>,
        captured_generic_params: Vec<CapturedLambdaParam>,
        /// The unit which is the `call` method on this lambda
        lambda_unit: NameID,
        arguments: Vec<Loc<Pattern>>,
        body: Box<Loc<Expression>>,
    },
    StageValid,
    StageReady,
    /// An unreachable marker carrying a string (presumably the message to report — confirm).
    /// `runtime_requirement_witness` treats it as compile-time evaluable.
    StaticUnreachable(Loc<String>),
    // This is a special case expression which is never created in user code, but which can be used
    // in type inference to create virtual expressions with specific IDs
    Null,
}
impl WithLocation for ExprKind {}
271
272impl ExprKind {
273    pub fn with_id(self, id: ExprID) -> Expression {
274        Expression { kind: self, id }
275    }
276
277    // FIXME: These really should be #[cfg(test)]'d away
278    pub fn idless(self) -> Expression {
279        Expression {
280            kind: self,
281            id: ExprID(0),
282        }
283    }
284
285    pub fn int_literal(val: i32) -> Self {
286        Self::IntLiteral(val.to_bigint(), IntLiteralKind::Unsized)
287    }
288}
289
/// A HIR expression: its kind, plus an ID.
///
/// NOTE: `PartialEq` for this type is implemented by hand in this file and compares only
/// `kind`, so two expressions that differ only in `id` compare equal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Expression {
    pub kind: ExprKind,
    // This ID is used to associate types with the expression
    pub id: ExprID,
}
impl WithLocation for Expression {}
297
298impl Expression {
299    /// Create a new expression referencing an identifier with the specified
300    /// id and name
301    pub fn ident(expr_id: ExprID, name_id: u64, name: &str) -> Expression {
302        ExprKind::Identifier(NameID(name_id, Path::from_strs(&[name]))).with_id(expr_id)
303    }
304
305    /// Returns the block that is this expression if it is a block, an error if it is an Error node, and panics if the expression is not a block or error
306    pub fn assume_block(&self) -> std::result::Result<&Block, ()> {
307        if let ExprKind::Block(ref block) = self.kind {
308            Ok(block)
309        } else if let ExprKind::Error = self.kind {
310            Err(())
311        } else {
312            panic!("Expression is not a block")
313        }
314    }
315
316    /// Returns the block that is this expression. Panics if the expression is not a block
317    pub fn assume_block_mut(&mut self) -> &mut Block {
318        if let ExprKind::Block(block) = &mut self.kind {
319            block.borrow_mut()
320        } else {
321            panic!("Expression is not a block")
322        }
323    }
324}
325
326impl PartialEq for Expression {
327    fn eq(&self, other: &Self) -> bool {
328        self.kind == other.kind
329    }
330}
331
/// Extension methods on located expressions (`Loc<Expression>`).
pub trait LocExprExt {
    /// Returns a (sub-)expression which requires runtime evaluation, or `None` if the
    /// expression can be evaluated at compile time.
    fn runtime_requirement_witness(&self) -> Option<Loc<Expression>>;
}
335
336impl LocExprExt for Loc<Expression> {
337    /// Checks if the expression is evaluable at compile time, returning a Loc of
338    /// a (sub)-expression which requires runtime, and None if it is comptime evaluable.
339    ///
340    /// If this method returns None, `.eval()` on the resulting list of mir statements is
341    /// guaranteed to work
342    fn runtime_requirement_witness(&self) -> Option<Loc<Expression>> {
343        match &self.kind {
344            ExprKind::Error => Some(self.clone()),
345            ExprKind::Identifier(_) => Some(self.clone()),
346            ExprKind::TypeLevelInteger(_) => None,
347            ExprKind::IntLiteral(_, _) => None,
348            ExprKind::BoolLiteral(_) => None,
349            ExprKind::BitLiteral(_) => Some(self.clone()),
350            ExprKind::TupleLiteral(inner) => {
351                inner.iter().find_map(Self::runtime_requirement_witness)
352            }
353            ExprKind::ArrayLiteral(inner) => {
354                inner.iter().find_map(Self::runtime_requirement_witness)
355            }
356            ExprKind::ArrayShorthandLiteral(inner, _) => inner.runtime_requirement_witness(),
357            ExprKind::CreatePorts => Some(self.clone()),
358            ExprKind::Index(l, r) => l
359                .runtime_requirement_witness()
360                .or_else(|| r.runtime_requirement_witness()),
361            ExprKind::RangeIndex { .. } => Some(self.clone()),
362            ExprKind::TupleIndex(l, _) => l.runtime_requirement_witness(),
363            ExprKind::FieldAccess(l, _) => l.runtime_requirement_witness(),
364            // NOTE: We probably shouldn't see this here since we'll have lowered
365            // methods at this point, but this function doesn't throw
366            ExprKind::MethodCall { .. } | ExprKind::Call { .. } => Some(self.clone()),
367            ExprKind::BinaryOperator(l, operator, r) => {
368                if let Some(witness) = l
369                    .runtime_requirement_witness()
370                    .or_else(|| r.runtime_requirement_witness())
371                {
372                    Some(witness)
373                } else {
374                    match &operator.inner {
375                        BinaryOperator::Add => None,
376                        BinaryOperator::Sub => None,
377                        BinaryOperator::Mul
378                        | BinaryOperator::Div
379                        | BinaryOperator::Mod
380                        | BinaryOperator::Eq
381                        | BinaryOperator::NotEq
382                        | BinaryOperator::Gt
383                        | BinaryOperator::Lt
384                        | BinaryOperator::Ge
385                        | BinaryOperator::Le
386                        | BinaryOperator::LeftShift
387                        | BinaryOperator::RightShift
388                        | BinaryOperator::ArithmeticRightShift
389                        | BinaryOperator::LogicalAnd
390                        | BinaryOperator::LogicalOr
391                        | BinaryOperator::LogicalXor
392                        | BinaryOperator::BitwiseOr
393                        | BinaryOperator::BitwiseAnd
394                        | BinaryOperator::BitwiseXor => Some(self.clone()),
395                    }
396                }
397            }
398            ExprKind::UnaryOperator(op, operand) => {
399                if let Some(witness) = operand.runtime_requirement_witness() {
400                    Some(witness)
401                } else {
402                    match op.inner {
403                        UnaryOperator::Sub => None,
404                        UnaryOperator::Not
405                        | UnaryOperator::BitwiseNot
406                        | UnaryOperator::Dereference
407                        | UnaryOperator::Reference => Some(self.clone()),
408                    }
409                }
410            }
411            ExprKind::Match(_, _) => Some(self.clone()),
412            ExprKind::Block(_) => Some(self.clone()),
413            ExprKind::If(_, _, _) => Some(self.clone()),
414            ExprKind::TypeLevelIf(_, _, _) => Some(self.clone()),
415            ExprKind::PipelineRef { .. } => Some(self.clone()),
416            ExprKind::StageReady => Some(self.clone()),
417            ExprKind::StageValid => Some(self.clone()),
418            ExprKind::LambdaDef { .. } => Some(self.clone()),
419            ExprKind::StaticUnreachable(_) => None,
420            ExprKind::Null => None,
421        }
422    }
423}