From 5813d9ea4fc1c4adfdb1f324657292b85fdcb552 Mon Sep 17 00:00:00 2001 From: Emma Turner Date: Fri, 11 Apr 2025 20:09:50 +0100 Subject: [PATCH] RISC-V: switch block dispatch to use C fn pointer directly --- src/riscv/lib/src/jit.rs | 102 +++++++++--------- src/riscv/lib/src/jit/builder.rs | 10 +- .../src/machine_state/block_cache/block.rs | 80 ++++++-------- 3 files changed, 92 insertions(+), 100 deletions(-) diff --git a/src/riscv/lib/src/jit.rs b/src/riscv/lib/src/jit.rs index ebe4a6659a73..ecd08d97e2e2 100644 --- a/src/riscv/lib/src/jit.rs +++ b/src/riscv/lib/src/jit.rs @@ -9,6 +9,7 @@ mod builder; pub mod state_access; use std::collections::HashMap; +use std::ffi::c_void; use cranelift::codegen::CodegenError; use cranelift::codegen::ir::types::I64; @@ -35,49 +36,29 @@ use crate::state_backend::hash::Hash; use crate::traps::EnvironException; /// Alias for the function signature produced by the JIT compilation. -type JitFn = unsafe extern "C" fn( +/// +/// This must have the same ABI as [`Dispatch`], which is used by +/// the block dispatch mechanism in the block cache. +/// +/// However, the JitFn itself does not inspect its first and last parameters. +/// These parameters are needed by the initial dispatch mechanism to enable +/// JIT compilation and hot-swapping. To avoid over-specifying these parameters here +/// (which can, among other things, cause type-checking issues), we replace the parameters +/// with pointers to `c_void`, which in the C ABI are passed in the same way as the +/// thin references to the actual variables passed. +/// +/// [`Dispatch`]: crate::machine_state::block_cache::block::Dispatch +pub type JitFn = unsafe extern "C" fn( + // ignored + *const c_void, &mut MachineCoreState, u64, &mut usize, &mut Result<(), EnvironException>, + // ignored + *const c_void, ); -/// A jit-compiled function that can be [called] over [`MachineCoreState`].
-/// -/// [called]: Self::call -pub struct JCall { - fun: JitFn, -} - -impl JCall { - /// Run the jit-compiled function over the state. - /// - /// # Safety - /// - /// When calling, the [JIT] that compiled this function *must* - /// still be alive. - pub unsafe fn call( - &self, - core: &mut MachineCoreState, - pc: u64, - steps: &mut usize, - ) -> Result<(), EnvironException> { - let mut res = Ok(()); - - unsafe { - (self.fun)(core, pc, steps, &mut res); - } - - res - } -} - -impl Clone for JCall { - fn clone(&self) -> Self { - Self { fun: self.fun } - } -} - /// Errors that may arise from the initialisation of the JIT. #[derive(Debug, Error)] pub enum JitError { @@ -115,7 +96,7 @@ pub struct JIT { jsa_imports: JsaImports, /// Cache of compilation results. - cache: HashMap>>, + cache: HashMap>>, } impl JIT { @@ -155,13 +136,13 @@ impl JIT { /// /// Not all instructions are currently supported. For blocks containing /// unsupported instructions, `None` will be returned. - pub fn compile(&mut self, instr: &[Instruction]) -> Option> { + pub fn compile(&mut self, instr: &[Instruction]) -> Option> { let Ok(hash) = Hash::blake2b_hash(instr) else { return None; }; if let Some(compilation_result) = self.cache.get(&hash) { - return compilation_result.clone(); + return *compilation_result; } let mut builder = self.start(); @@ -170,6 +151,7 @@ impl JIT { let Some(lower) = i.opcode.to_lowering() else { builder.fail(); self.clear(); + self.cache.insert(hash, None); return None; }; @@ -213,10 +195,15 @@ impl JIT { fn start(&mut self) -> Builder<'_, MC, JSA> { let ptr = self.module.target_config().pointer_type(); + // first param ignored + self.ctx.func.signature.params.push(AbiParam::new(ptr)); + // params self.ctx.func.signature.params.push(AbiParam::new(ptr)); self.ctx.func.signature.params.push(AbiParam::new(I64)); self.ctx.func.signature.params.push(AbiParam::new(ptr)); self.ctx.func.signature.params.push(AbiParam::new(ptr)); + // last param ignored + 
self.ctx.func.signature.params.push(AbiParam::new(ptr)); let builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); let jsa_call = JsaCalls::func_calls(&mut self.module, &self.jsa_imports, ptr); @@ -225,16 +212,15 @@ impl JIT { } /// Finalise and cache the function under construction. - fn produce_function(&mut self, hash: &Hash) -> JCall { + fn produce_function(&mut self, hash: &Hash) -> JitFn { let name = hex::encode(hash); let fun = self.finalise(&name); - let jcall = JCall { fun }; - self.cache.insert(*hash, Some(jcall.clone())); + self.cache.insert(*hash, Some(fun)); block_metrics!(hash = hash, record_jitted); - JCall { fun } + fun } /// Finalise the function currently under construction. @@ -277,6 +263,8 @@ impl Default for JIT { #[cfg(test)] mod tests { + use std::ptr::null; + use Instruction as I; use super::*; @@ -381,10 +369,19 @@ mod tests { interpreted_bb, ) }; - let jitted_res = unsafe { + + let mut jitted_res = Ok(()); + unsafe { // # Safety - the block builder is alive for at least // the duration of the `run` function. - fun.call(&mut jitted, initial_pc, &mut jitted_steps) + (fun)( + null(), + &mut jitted, + initial_pc, + &mut jitted_steps, + &mut jitted_res, + null(), + ) }; // Assert state equality. @@ -1629,10 +1626,19 @@ mod tests { let fun = jit .compile(instructions(&block).as_slice()) .expect("Compilation of subsequent functions should succeed"); - let jitted_res = unsafe { + + let mut jitted_res = Ok(()); + unsafe { // # Safety - the jit is not dropped until after we // exit the block. 
- fun.call(&mut jitted, initial_pc, &mut jitted_steps) + (fun)( + null(), + &mut jitted, + initial_pc, + &mut jitted_steps, + &mut jitted_res, + null(), + ) }; assert!(jitted_res.is_ok()); diff --git a/src/riscv/lib/src/jit/builder.rs b/src/riscv/lib/src/jit/builder.rs index 6c0fd5c25c7f..8a931d1cfefa 100644 --- a/src/riscv/lib/src/jit/builder.rs +++ b/src/riscv/lib/src/jit/builder.rs @@ -89,10 +89,12 @@ impl<'a, MC: MemoryConfig, JSA: JitStateAccess> Builder<'a, MC, JSA> { builder.switch_to_block(entry_block); builder.seal_block(entry_block); - let core_ptr_val = builder.block_params(entry_block)[0]; - let pc_val = X64(builder.block_params(entry_block)[1]); - let steps_ptr_val = builder.block_params(entry_block)[2]; - let result_ptr_val = builder.block_params(entry_block)[3]; + // first param ignored + let core_ptr_val = builder.block_params(entry_block)[1]; + let pc_val = X64(builder.block_params(entry_block)[2]); + let steps_ptr_val = builder.block_params(entry_block)[3]; + let result_ptr_val = builder.block_params(entry_block)[4]; + // last param ignored Self { ptr, diff --git a/src/riscv/lib/src/machine_state/block_cache/block.rs b/src/riscv/lib/src/machine_state/block_cache/block.rs index 050b3f317669..f9ba819e74c4 100644 --- a/src/riscv/lib/src/machine_state/block_cache/block.rs +++ b/src/riscv/lib/src/machine_state/block_cache/block.rs @@ -11,8 +11,8 @@ use super::CACHE_INSTR; use super::ICallPlaced; use super::run_instr; use crate::default::ConstDefault; -use crate::jit::JCall; use crate::jit::JIT; +use crate::jit::JitFn; use crate::jit::state_access::JitStateAccess; use crate::machine_state::MachineCoreState; use crate::machine_state::ProgramCounterUpdate; @@ -248,13 +248,14 @@ impl Clone for Interpreted { /// /// Internally, this may be interpreted, just-in-time compiled, or do /// additional work over just execution. 
-type Dispatch = unsafe fn( +pub type Dispatch = unsafe extern "C" fn( &mut InlineJit, &mut MachineCoreState, Address, &mut usize, + &mut Result<(), EnvironException>, &mut as Block>::BlockBuilder, -) -> Result<(), EnvironException>; +); /// Blocks that are compiled to native code for execution, when possible. /// @@ -264,7 +265,6 @@ type Dispatch = unsafe fn( /// Blocks are compiled upon calling [`Block::run_block`], in a *stop the world* fashion. pub struct InlineJit { fallback: Interpreted, - jit_fn: Option>, dispatch: Dispatch, } @@ -280,13 +280,14 @@ impl InlineJit { /// /// This ensures that the builder in question is guaranteed to be alive, for at least as long /// as this block may be run via [`Block::run_block`]. - unsafe fn run_block_interpreted( + unsafe extern "C" fn run_block_interpreted( &mut self, core: &mut MachineCoreState, instr_pc: Address, steps: &mut usize, + result: &mut Result<(), EnvironException>, block_builder: &mut >::BlockBuilder, - ) -> Result<(), EnvironException> { + ) { // trigger JIT compilation let instr = self .fallback @@ -296,17 +297,22 @@ impl InlineJit { .map(|i| i.read_stored()) .collect::>(); - self.jit_fn = block_builder.0.compile(&instr); + let fun = match block_builder.0.compile(&instr) { + Some(jitfn) => { + // Safety: the two function signatures are identical, apart from the first and + // last parameters. These are both pointers, and ignored by the JitFn. + // + // It's therefore safe to cast these to thin-pointers to any type. 
+ unsafe { std::mem::transmute::, Dispatch>(jitfn) } + } + None => Self::run_block_not_compiled, + }; - if self.jit_fn.is_some() { - self.dispatch = Self::run_block_compiled; - } else { - self.dispatch = Self::run_block_not_compiled; - } + self.dispatch = fun; // Safety: the block builder passed to this function is always the same for the // lifetime of the block - unsafe { (self.dispatch)(self, core, instr_pc, steps, block_builder) } + unsafe { (fun)(self, core, instr_pc, steps, result, block_builder) } } /// Run a block where JIT-compilation has been attempted, but failed for any reason. @@ -317,40 +323,19 @@ impl InlineJit { /// /// This ensures that the builder in question is guaranteed to be alive, for at least as long /// as this block may be run via [`Block::run_block`]. - unsafe fn run_block_not_compiled( + unsafe extern "C" fn run_block_not_compiled( &mut self, core: &mut MachineCoreState, instr_pc: Address, steps: &mut usize, + result: &mut Result<(), EnvironException>, block_builder: &mut >::BlockBuilder, - ) -> Result<(), EnvironException> { - // Safety: this function is always safe to call - unsafe { + ) { + *result = unsafe { + // Safety: this function is always safe to call self.fallback .run_block(core, instr_pc, steps, &mut block_builder.1) - } - } - - /// Run a block using the result of JIT-compilation, where this has been successful. - /// - /// # SAFETY - /// - /// The `block_builder` must be the same every time this function is called. - /// - /// This ensures that the builder in question is guaranteed to be alive, for at least as long - /// as this block may be run via [`Block::run_block`]. 
- unsafe fn run_block_compiled( - &mut self, - core: &mut MachineCoreState, - instr_pc: Address, - steps: &mut usize, - _block_builder: &mut >::BlockBuilder, - ) -> Result<(), EnvironException> { - let fun = self.jit_fn.as_ref().unwrap(); - - // Safety: the block builder passed to this function is always the same for the - // lifetime of the block - unsafe { fun.call(core, instr_pc, steps) } + }; } } @@ -361,7 +346,6 @@ impl NewState for InlineJit { { Self { fallback: Interpreted::new(manager), - jit_fn: None, dispatch: Self::run_block_interpreted, } } @@ -375,7 +359,6 @@ impl Block for InlineJit { M: ManagerWrite, { self.dispatch = Self::run_block_interpreted; - self.jit_fn = None; self.fallback.start_block() } @@ -384,7 +367,6 @@ impl Block for InlineJit { M: ManagerWrite, { self.dispatch = Self::run_block_interpreted; - self.jit_fn = None; self.fallback.invalidate() } @@ -393,7 +375,6 @@ impl Block for InlineJit { M: ManagerReadWrite, { self.dispatch = Self::run_block_interpreted; - self.jit_fn = None; self.fallback.reset() } @@ -402,7 +383,6 @@ impl Block for InlineJit { M: ManagerReadWrite, { self.dispatch = Self::run_block_interpreted; - self.jit_fn = None; self.fallback.push_instr(instr) } @@ -416,7 +396,6 @@ impl Block for InlineJit { fn bind(allocated: AllocatedOf) -> Self { Self { fallback: Interpreted::bind(allocated), - jit_fn: None, dispatch: Self::run_block_interpreted, } } @@ -443,7 +422,13 @@ impl Block for InlineJit { where M: ManagerReadWrite, { - unsafe { (self.dispatch)(self, core, instr_pc, steps, block_builder) } + let mut result = Ok(()); + + // Safety: the block builder is always the same instance, which guarantees that any + // jit-compiled function is still alive.
+ unsafe { (self.dispatch)(self, core, instr_pc, steps, &mut result, block_builder) }; + + result } fn num_instr(&self) -> usize @@ -458,7 +443,6 @@ impl Clone for InlineJit Self { Self { fallback: self.fallback.clone(), - jit_fn: None, dispatch: Self::run_block_interpreted, } } -- GitLab