use std::fs;
use std::env;
use std::hash::{Hash, Hasher};
use std::io::{self, BufRead, Seek};
use std::panic;
use std::process;
use fnv;
use tempfile;
use cmdline;
use error::*;
/// Environment variable in which each parent records the fork ids it has
/// already entered, so the re-executed child can recognise itself.
const OCCURS_ENV: &str = "RUSTY_FORK_OCCURS";
/// Byte length of one fork-id term as built by `id_str`: a leading `:`
/// followed by 16 hexadecimal digits.
const OCCURS_TERM_LENGTH: usize = 1 + 16;
/// Run `in_child` in a freshly spawned copy of the current executable while
/// `in_parent` supervises the spawned process.
///
/// * `test_name` is forwarded on the child's command line.
/// * `fork_id` is hashed into the marker that lets the re-executed binary
///   recognise that it is the child for this particular call.
/// * `process_modifier` may adjust the `Command` before it is spawned.
/// * `in_parent` receives the spawned child and the file its output is
///   redirected to; its return value becomes this function's result.
/// * `in_child` is the code run inside the child process.
///
/// In the child process this call does not return: the process exits after
/// `in_child` finishes.
pub fn fork<ID, MODIFIER, PARENT, CHILD, R>(
    test_name: &str,
    fork_id: ID,
    process_modifier: MODIFIER,
    in_parent: PARENT,
    in_child: CHILD,
) -> Result<R>
where
    ID: Hash,
    MODIFIER: FnOnce(&mut process::Command),
    PARENT: FnOnce(&mut process::Child, &mut fs::File) -> R,
    CHILD: FnOnce(),
{
    // `fork_impl` takes `&mut FnMut` trait objects, so each `FnOnce` is parked
    // in an `Option` and moved out on its single invocation.
    let marker = id_str(fork_id);
    let mut modifier_slot = Some(process_modifier);
    let mut parent_slot = Some(in_parent);
    let mut child_slot = Some(in_child);
    let mut parent_result = None;

    fork_impl(
        test_name,
        marker,
        &mut |cmd| {
            let modify = modifier_slot.take().unwrap();
            modify(cmd)
        },
        &mut |child, file| {
            let supervise = parent_slot.take().unwrap();
            parent_result = Some(supervise(child, file));
        },
        &mut || {
            let body = child_slot.take().unwrap();
            body()
        },
    )?;

    Ok(parent_result.unwrap())
}
/// Non-generic core of [`fork`]: runs the child body when this process is
/// the forked child, otherwise spawns and supervises that child.
///
/// If `fork_id` already appears in the `RUSTY_FORK_OCCURS` environment
/// variable, this process is the child. `in_child` runs and the process then
/// exits with status 0, or 70 if `in_child` panicked. The function does not
/// return in that case.
///
/// Otherwise the current executable is spawned again with:
/// * the stripped command line, `cmdline::RUN_TEST_ARGS` and `test_name`;
/// * `fork_id` appended to the environment variable;
/// * stdin set to null;
/// * stdout/stderr redirected to a file from `tempfile::tempfile()`, unless
///   `process_modifier` overrides them.
///
/// `in_parent` then receives the child and that file. Once `in_parent`
/// returns or unwinds, a child that is still running is killed and reaped.
/// Everything written to the file is then printed on this process's stdout.
///
/// # Errors
///
/// Returns an error if the temporary file cannot be created or cloned, if
/// `cmdline::strip_cmdline` fails, or if the child cannot be spawned.
///
/// # Panics
///
/// Panics if the environment variable is already longer than 16 terms
/// (`16 * OCCURS_TERM_LENGTH` bytes), as a guard against runaway recursion.
/// Also panics if `env::current_exe()` fails.
fn fork_impl(test_name: &str, fork_id: String,
             process_modifier: &mut FnMut(&mut process::Command),
             in_parent: &mut FnMut(&mut process::Child, &mut fs::File),
             in_child: &mut FnMut()) -> Result<()> {
    let mut occurs = env::var(OCCURS_ENV).unwrap_or_else(|_| String::new());

    if occurs.contains(&fork_id) {
        // We are the re-executed child: run the body and exit explicitly so
        // this function never returns in the child.
        match panic::catch_unwind(panic::AssertUnwindSafe(in_child)) {
            Ok(_) => process::exit(0),
            // NOTE(review): 70 presumably mirrors EX_SOFTWARE from sysexits.h;
            // the `child_aborted_if_panics` test pins this exact value.
            Err(_) => process::exit(70),
        }
    }

    if occurs.len() > 16 * OCCURS_TERM_LENGTH {
        panic!("rusty-fork: Not forking due to >=16 levels of recursion");
    }

    // Owns the spawned child and the file capturing its output. Dropping it
    // (on normal return, or while unwinding from a panic in `in_parent`)
    // stops the child and replays its output.
    struct KillOnDrop(process::Child, fs::File);

    impl Drop for KillOnDrop {
        fn drop(&mut self) {
            // Kill the child if `in_parent` left it running, then reap it.
            // `Child` does not wait on drop, so without `wait()` a killed
            // child would remain a zombie until this process exits. Waiting
            // also ensures the child has stopped writing to the shared file
            // before it is replayed. We only wait when the kill succeeded,
            // so a failed kill can never block inside `drop`.
            if let Ok(None) = self.0.try_wait() {
                if self.0.kill().is_ok() {
                    let _ = self.0.wait();
                }
            }

            // Replay the captured output, best-effort (errors are ignored,
            // since this runs in `drop`). Lines are split on raw b'\n' and
            // converted lossily, so non-UTF-8 output cannot stop the replay.
            // `print!` (not `println!`) because each chunk keeps its own
            // line ending.
            let _ = self.1.seek(io::SeekFrom::Start(0));
            let mut line = Vec::new();
            let mut reader = io::BufReader::new(&mut self.1);
            loop {
                line.clear();
                match reader.read_until(b'\n', &mut line) {
                    Ok(0) | Err(_) => break,
                    Ok(_) => print!("{}", String::from_utf8_lossy(&line)),
                }
            }
        }
    }

    let file = tempfile::tempfile()?;
    // Record this fork so the child recognises itself via `occurs.contains`.
    occurs.push_str(&fork_id);

    let mut command = process::Command::new(
        env::current_exe().expect("current_exe() failed, cannot fork"));
    command
        .args(cmdline::strip_cmdline(env::args())?)
        .args(cmdline::RUN_TEST_ARGS)
        .arg(test_name)
        .env(OCCURS_ENV, &occurs)
        .stdin(process::Stdio::null())
        .stdout(file.try_clone()?)
        .stderr(file.try_clone()?);
    process_modifier(&mut command);

    let mut child = command.spawn().map(|p| KillOnDrop(p, file))?;
    in_parent(&mut child.0, &mut child.1);
    Ok(())
}
/// Hash `id` into the marker term that is recorded in `RUSTY_FORK_OCCURS`.
///
/// The result is always `OCCURS_TERM_LENGTH` (17) bytes long. It is a `:`
/// followed by the 64-bit `FnvHasher` hash of `id`, written as 16
/// zero-padded uppercase hex digits.
///
/// Every term starts with `:` and the hex digits never contain one. So when
/// a whole term is searched for in the concatenated variable, a match can
/// only line up on a term boundary.
fn id_str<ID: Hash>(id: ID) -> String {
    let mut hasher = fnv::FnvHasher::default();
    id.hash(&mut hasher);
    format!(":{:016X}", hasher.finish())
}
// Tests for `fork`. Each test re-executes this test binary. The first
// argument of every `fork` call is passed to the child on its command line
// (`.arg(test_name)` in `fork_impl`). It presumably has to name this same
// test so that the child runs it again; TODO confirm against
// `cmdline::RUN_TEST_ARGS`.
// Several tests nest a second `fork` inside the first child's body, so the
// innermost closure runs in a grandchild process.
#[cfg(test)]
mod test {
use std::io::Read;
use std::thread;
use super::*;
// Block the current thread for `ms` milliseconds.
fn sleep(ms: u64) {
thread::sleep(::std::time::Duration::from_millis(ms));
}
// Process modifier: pipe the child's stdout back to the parent (overriding
// the temporary-file redirection) and let stderr pass through.
fn capturing_output(cmd: &mut process::Command) {
cmd.stdout(process::Stdio::piped())
.stderr(process::Stdio::inherit());
}
// Process modifier: the child writes directly to this process's stdout and
// stderr instead of the temporary file.
fn inherit_output(cmd: &mut process::Command) {
cmd.stdout(process::Stdio::inherit())
.stderr(process::Stdio::inherit());
}
// Parent callback: read the child's piped stdout to EOF, assert that the
// child exited successfully, and return the text. Panics unless stdout was
// piped (e.g. via `capturing_output`), because `child.stdout` is unwrapped.
fn wait_for_child_output(child: &mut process::Child,
_file: &mut fs::File) -> String {
let mut output = String::new();
child.stdout.as_mut().unwrap().read_to_string(&mut output).unwrap();
assert!(child.wait().unwrap().success());
output
}
// Parent callback: wait for the child and assert that it exited successfully.
fn wait_for_child(child: &mut process::Child,
_file: &mut fs::File) {
assert!(child.wait().unwrap().success());
}
// Single fork: the child prints and returns normally, so the parent sees a
// successful exit status.
#[test]
fn fork_basically_works() {
let status =
fork("fork::test::fork_basically_works", rusty_fork_id!(),
|_| (),
|child, _| child.wait().unwrap(),
|| println!("hello from child")).unwrap();
assert!(status.success());
}
// The inner fork keeps the default redirection, so the grandchild's text goes
// to the inner fork's temporary file. It can only reach the outer pipe
// (`capturing_output`) through `fork_impl` replaying that file with `print!`.
#[test]
fn child_output_captured_and_repeated() {
let output = fork(
"fork::test::child_output_captured_and_repeated",
rusty_fork_id!(),
capturing_output, wait_for_child_output,
|| fork(
"fork::test::child_output_captured_and_repeated",
rusty_fork_id!(),
|_| (), wait_for_child,
|| println!("hello from child")).unwrap())
.unwrap();
assert!(output.contains("hello from child"));
}
// The inner parent returns immediately (`|_, _| ()`). Its drop-guard should
// therefore kill the grandchild before the grandchild wakes from its 1s sleep.
// The grandchild inherits the middle process's stdout (the outer pipe), so any
// surviving print would show up in `output`.
// NOTE(review): the trailing `sleep(2_000)` runs only in the top-level
// process, after `output` has already been collected.
#[test]
fn child_killed_if_parent_exits_first() {
let output = fork(
"fork::test::child_killed_if_parent_exits_first",
rusty_fork_id!(),
capturing_output, wait_for_child_output,
|| fork(
"fork::test::child_killed_if_parent_exits_first",
rusty_fork_id!(),
inherit_output, |_, _| (),
|| {
sleep(1_000);
println!("hello from child");
}).unwrap()).unwrap();
sleep(2_000);
assert!(!output.contains("hello from child"),
"Had unexpected output:\n{}", output);
}
// Like the test above, but the inner parent panics instead of returning. The
// drop-guard runs during unwinding (the panic is caught by `catch_unwind`)
// and should still kill the grandchild before it prints.
// NOTE(review): the trailing `sleep(2_000)` runs only in the top-level
// process, after `output` has already been collected.
#[test]
fn child_killed_if_parent_panics_first() {
let output = fork(
"fork::test::child_killed_if_parent_panics_first",
rusty_fork_id!(),
capturing_output, wait_for_child_output,
|| {
assert!(
panic::catch_unwind(panic::AssertUnwindSafe(|| fork(
"fork::test::child_killed_if_parent_panics_first",
rusty_fork_id!(),
inherit_output,
|_, _| panic!("testing a panic, nothing to see here"),
|| {
sleep(1_000);
println!("hello from child");
}).unwrap())).is_err());
}).unwrap();
sleep(2_000);
assert!(!output.contains("hello from child"),
"Had unexpected output:\n{}", output);
}
// A panic in the child is caught inside `fork_impl` and turned into exit
// status 70.
#[test]
fn child_aborted_if_panics() {
let status = fork(
"fork::test::child_aborted_if_panics",
rusty_fork_id!(),
|_| (),
|child, _| child.wait().unwrap(),
|| panic!("testing a panic, nothing to see here")).unwrap();
assert_eq!(70, status.code().unwrap());
}
}