// Review notes for the patch below. It is a unified diff whose line breaks were lost,
// so it is collapsed onto two lines and will not apply as-is.
//
// contrib/mir/src/ast/comparable.rs, `impl PartialOrd for TypedValue`:
// - Replaces the catch-all `_ => None` with one `(Variant(..), _) => None` fallback per
//   comparable variant. It adds an explicit `(List(..) | Map(..) | Contract(..), _) => None`
//   arm for the non-comparable variants.
// - Result: values of different variants, and any value of a non-comparable variant,
//   still compare as `None`.
// - With no wildcard arm left, a TypedValue variant not named in the match is a
//   non-exhaustive-match compile error. Presumably that is the intent of the change.
// - NOTE(review): `Or` ordering is delegated to the PartialOrd impl of the `Or` type
//   (matched as `Or::Left` / `Or::Right` in gas.rs), which is not shown here. Michelson
//   orders every Left before every Right; confirm that impl agrees.
//
// contrib/mir/src/gas.rs, `mod interpret_cost` (comparison cost, recursing via `compare`):
// - Imports `Or` and adds `cmp_or = Checked::from(10u32)`, the same value as `cmp_option`.
// - Adds `#[track_caller] fn incomparable() -> !`, which panics through
//   `unreachable!("Comparison of incomparable values")`.
//   - `#[track_caller]` makes the reported panic location the match arm that called it.
//   - It replaces the former `_ => unreachable!(...)` arm, again with one
//     `(V::Variant(_), _)` arm per variant.
// - Adds `Or` costing:
//   - Same side (Left/Left, Right/Right) costs `cmp_or` plus the recursive `compare`
//     of the payloads.
//   - Different sides cost only `cmp_or`, because no payload comparison is performed.
diff --git a/contrib/mir/src/ast/comparable.rs b/contrib/mir/src/ast/comparable.rs index b455fc741d4193930bde02a2489b645400b7caeb..d3552427c51d49b707bd7895a2a6ada9b9527fb5 100644 --- a/contrib/mir/src/ast/comparable.rs +++ b/contrib/mir/src/ast/comparable.rs @@ -5,17 +5,40 @@ impl PartialOrd for TypedValue { use TypedValue::*; match (self, other) { (Int(a), Int(b)) => a.partial_cmp(b), + (Int(..), _) => None, + (Nat(a), Nat(b)) => a.partial_cmp(b), + (Nat(..), _) => None, + (Mutez(a), Mutez(b)) => a.partial_cmp(b), + (Mutez(..), _) => None, + (Bool(a), Bool(b)) => a.partial_cmp(b), + (Bool(..), _) => None, + (String(a), String(b)) => a.partial_cmp(b), + (String(..), _) => None, + (Unit, Unit) => Some(std::cmp::Ordering::Equal), + (Unit, _) => None, + (Pair(l), Pair(r)) => l.partial_cmp(r), + (Pair(..), _) => None, + (Option(x), Option(y)) => x.as_deref().partial_cmp(&y.as_deref()), + (Option(..), _) => None, + (Or(x), Or(y)) => x.as_ref().partial_cmp(y.as_ref()), + (Or(..), _) => None, + (Address(l), Address(r)) => l.partial_cmp(r), + (Address(..), _) => None, + (ChainId(l), ChainId(r)) => l.partial_cmp(r), - _ => None, + (ChainId(..), _) => None, + + // non-comparable types + (List(..) | Map(..) | Contract(..), _) => None, } } } diff --git a/contrib/mir/src/gas.rs b/contrib/mir/src/gas.rs index d65bffb4d48d6c46c13167590523f889bdf1f912..33b740c0b47d78e65a7e08345618f6738200ca9a 100644 --- a/contrib/mir/src/gas.rs +++ b/contrib/mir/src/gas.rs @@ -152,7 +152,7 @@ pub mod interpret_cost { use checked::Checked; use super::{AsGasCost, OutOfGas}; - use crate::ast::TypedValue; + use crate::ast::{Or, TypedValue}; pub const DIP: u32 = 10; pub const DROP: u32 = 10; @@ -246,20 +246,39 @@ pub mod interpret_cost { let cmp_option = Checked::from(10u32); const ADDRESS_SIZE: usize = 20 + 31; // hash size + max entrypoint size const CMP_CHAIN_ID: u32 = 30; + let cmp_or = Checked::from(10u32); + #[track_caller] + fn incomparable() -> !
{ unreachable!("Comparison of incomparable values") } Ok(match (v1, v2) { (V::Nat(l), V::Nat(r)) => { // NB: eventually when using BigInts, use BigInt::bits() &c cmp_bytes(std::mem::size_of_val(l), std::mem::size_of_val(r))? } + (V::Nat(_), _) => incomparable(), + (V::Int(l), V::Int(r)) => { // NB: eventually when using BigInts, use BigInt::bits() &c cmp_bytes(std::mem::size_of_val(l), std::mem::size_of_val(r))? } + (V::Int(_), _) => incomparable(), + (V::Bool(_), V::Bool(_)) => cmp_bytes(1, 1)?, + (V::Bool(_), _) => incomparable(), + (V::Mutez(_), V::Mutez(_)) => cmp_bytes(8, 8)?, + (V::Mutez(_), _) => incomparable(), + (V::String(l), V::String(r)) => cmp_bytes(l.len(), r.len())?, + (V::String(_), _) => incomparable(), + (V::Unit, V::Unit) => 10, + (V::Unit, _) => incomparable(), + (V::Pair(l), V::Pair(r)) => cmp_pair(l.as_ref(), r.as_ref())?, + (V::Pair(_), _) => incomparable(), + (V::Option(l), V::Option(r)) => match (l, r) { (None, None) => cmp_option, (None, Some(_)) => cmp_option, @@ -267,9 +286,24 @@ pub mod interpret_cost { (Some(l), Some(r)) => cmp_option + compare(l, r)?, } .as_gas_cost()?, + (V::Option(_), _) => incomparable(), + (V::Address(..), V::Address(..)) => cmp_bytes(ADDRESS_SIZE, ADDRESS_SIZE)?, + (V::Address(_), _) => incomparable(), + (V::ChainId(..), V::ChainId(..)) => CMP_CHAIN_ID, - _ => unreachable!("Comparison of incomparable values"), + (V::ChainId(_), _) => incomparable(), + + (V::Or(l), V::Or(r)) => match (l.as_ref(), r.as_ref()) { + (Or::Left(x), Or::Left(y)) => cmp_or + compare(x, y)?, + (Or::Right(x), Or::Right(y)) => cmp_or + compare(x, y)?, + (Or::Left(_), Or::Right(_)) => cmp_or, + (Or::Right(_), Or::Left(_)) => cmp_or, + } + .as_gas_cost()?, + (V::Or(..), _) => incomparable(), + + (V::List(..) | V::Map(..) | V::Contract(_), _) => incomparable(), }) }