blob: 2870b61f515ff15edadf0d8aa328b18aba2b0ef0 [file] [log] [blame]
//! Cost functions for egraph representation.
use crate::ir::Opcode;
/// A cost of computing some value in the program.
///
/// Costs are measured in an arbitrary union that we represent in a
/// `u32`. The ordering is meant to be meaningful, but the value of a
/// single unit is arbitrary (and "not to scale"). We use a collection
/// of heuristics to try to make this approximation at least usable.
///
/// We start by defining costs for each opcode (see `pure_op_cost`
/// below). The cost of computing some value, initially, is the cost
/// of its opcode, plus the cost of computing its inputs.
///
/// We then adjust the cost according to loop nests: for each
/// loop-nest level, we multiply by 1024. Because we only have 32
/// bits, we limit this scaling to a loop-level of two (i.e., multiply
/// by 2^20 ~= 1M).
///
/// Arithmetic on costs is always saturating: we don't want to wrap
/// around and return to a tiny cost when adding the costs of two very
/// expensive operations. It is better to approximate and lose some
/// precision than to lose the ordering by wrapping.
///
/// Finally, we reserve the highest value, `u32::MAX`, as a sentinel
/// that means "infinite". This is separate from the finite costs and
/// not reachable by doing arithmetic on them (even when overflowing)
/// -- we saturate just *below* infinity. (This is done by the
/// `finite()` method.) An infinite cost is used to represent a value
/// that cannot be computed, or otherwise serve as a sentinel when
/// performing search for the lowest-cost representation of a value.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);
impl core::fmt::Debug for Cost {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if *self == Cost::infinity() {
write!(f, "Cost::Infinite")
} else {
f.debug_struct("Cost::Finite")
.field("op_cost", &self.op_cost())
.field("depth", &self.depth())
.finish()
}
}
}
impl Ord for Cost {
#[inline]
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// We make sure that the high bits are the op cost and the low bits are
// the depth. This means that we can use normal integer comparison to
// order by op cost and then depth.
//
// We want to break op cost ties with depth (rather than the other way
// around). When the op cost is the same, we prefer shallow and wide
// expressions to narrow and deep expressions and breaking ties with
// `depth` gives us that. For example, `(a + b) + (c + d)` is preferred
// to `((a + b) + c) + d`. This is beneficial because it exposes more
// instruction-level parallelism and shortens live ranges.
self.0.cmp(&other.0)
}
}
impl PartialOrd for Cost {
#[inline]
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Cost {
const DEPTH_BITS: u8 = 8;
const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
const MAX_OP_COST: u32 = (Self::OP_COST_MASK >> Self::DEPTH_BITS) - 1;
pub(crate) fn infinity() -> Cost {
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
// only for heuristics and always saturate so this suffices!)
Cost(u32::MAX)
}
pub(crate) fn zero() -> Cost {
Cost(0)
}
/// Construct a new finite cost from the given parts.
///
/// The opcode cost is clamped to the maximum value representable.
fn new_finite(opcode_cost: u32, depth: u8) -> Cost {
let opcode_cost = std::cmp::min(opcode_cost, Self::MAX_OP_COST);
let cost = Cost((opcode_cost << Self::DEPTH_BITS) | u32::from(depth));
debug_assert_ne!(cost, Cost::infinity());
cost
}
fn depth(&self) -> u8 {
let depth = self.0 & Self::DEPTH_MASK;
u8::try_from(depth).unwrap()
}
fn op_cost(&self) -> u32 {
(self.0 & Self::OP_COST_MASK) >> Self::DEPTH_BITS
}
/// Compute the cost of the operation and its given operands.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
let c = pure_op_cost(op) + operand_costs.into_iter().sum();
Cost::new_finite(c.op_cost(), c.depth().saturating_add(1))
}
}
impl std::iter::Sum<Cost> for Cost {
fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
iter.fold(Self::zero(), |a, b| a + b)
}
}
impl std::default::Default for Cost {
fn default() -> Cost {
Cost::zero()
}
}
impl std::ops::Add<Cost> for Cost {
type Output = Cost;
fn add(self, other: Cost) -> Cost {
let op_cost = std::cmp::min(
self.op_cost().saturating_add(other.op_cost()),
Self::MAX_OP_COST,
);
let depth = std::cmp::max(self.depth(), other.depth());
Cost::new_finite(op_cost, depth)
}
}
/// Return the cost of a *pure* opcode.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
fn pure_op_cost(op: Opcode) -> Cost {
match op {
// Constants.
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new_finite(1, 0),
// Extends/reduces.
Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => {
Cost::new_finite(2, 0)
}
// "Simple" arithmetic.
Opcode::Iadd
| Opcode::Isub
| Opcode::Band
| Opcode::Bor
| Opcode::Bxor
| Opcode::Bnot
| Opcode::Ishl
| Opcode::Ushr
| Opcode::Sshr => Cost::new_finite(3, 0),
// Everything else (pure.)
_ => Cost::new_finite(4, 0),
}
}