Skip to content

Commit

Permalink
fix exit_status propagation, handle long nvcc language option (--x)
Browse files Browse the repository at this point in the history
  • Loading branch information
trxcllnt committed Oct 17, 2024
1 parent 6251699 commit 78f9ae0
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 25 deletions.
22 changes: 14 additions & 8 deletions src/compiler/cicc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,20 @@ where
);
continue;
}
Some(UnhashedInput(o)) => {
Some(UnhashedInputOrOutput(o)) => {
take_next = false;
let path = cwd.join(o);
if !path.exists() {
continue;
if path.exists() {
extra_inputs.push(path);
} else if let Some(flag) = arg.flag_str() {
outputs.insert(
flag,
ArtifactDescriptor {
path,
optional: false,
},
);
}
extra_inputs.push(path);
&mut unhashed_args
}
Some(UnhashedOutput(o)) => {
Expand All @@ -163,7 +170,7 @@ where
}
&mut unhashed_args
}
Some(UnhashedFlag) | Some(Unhashed(_)) => {
Some(UnhashedFlag) => {
take_next = false;
&mut unhashed_args
}
Expand Down Expand Up @@ -294,11 +301,10 @@ pub fn generate_compile_commands(

ArgData! { pub
Output(PathBuf),
UnhashedInput(PathBuf),
UnhashedInputOrOutput(PathBuf),
UnhashedOutput(PathBuf),
UnhashedFlag,
PassThrough(OsString),
Unhashed(OsString),
}

use self::ArgData::*;
Expand All @@ -308,7 +314,7 @@ counted_array!(pub static ARGS: [ArgInfo<ArgData>; _] = [
take_arg!("--gen_device_file_name", PathBuf, Separated, UnhashedOutput),
flag!("--gen_module_id_file", UnhashedFlag),
take_arg!("--include_file_name", OsString, Separated, PassThrough),
take_arg!("--module_id_file_name", PathBuf, Separated, UnhashedInput),
take_arg!("--module_id_file_name", PathBuf, Separated, UnhashedInputOrOutput),
take_arg!("--stub_file_name", PathBuf, Separated, UnhashedOutput),
take_arg!("-o", PathBuf, Separated, Output),
]);
55 changes: 40 additions & 15 deletions src/compiler/nvcc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ use std::collections::HashMap;
use std::ffi::{OsStr, OsString};
use std::future::{Future, IntoFuture};
use std::io::{self, BufRead, Read, Write};
#[cfg(unix)]
use std::os::unix::process::ExitStatusExt;
use std::path::{Path, PathBuf};
use std::process;
use which::which_in;
Expand Down Expand Up @@ -508,9 +510,9 @@ impl CompileCommandImpl for NvccCompileCommand {
};

let n = nvcc_subcommand_groups.len();
let cuda_front_end_range = if n < 1 { 0..0 } else { 0..1 };
let device_compile_range = if n < 2 { 0..0 } else { 1..n - 1 };
let final_assembly_range = if n < 3 { 0..0 } else { n - 1..n };
let cuda_front_end_range = if n > 0 { 0..1 } else { 0..0 };
let final_assembly_range = if n > 1 { n - 1..n } else { 0..0 };
let device_compile_range = if n > 2 { 1..n - 1 } else { 0..0 };

let num_parallel = device_compile_range.len().min(*num_parallel).max(1);

Expand All @@ -531,12 +533,7 @@ impl CompileCommandImpl for NvccCompileCommand {
output = aggregate_output(output, result.unwrap_or_else(error_to_output));
}

if output
.status
.code()
.and_then(|c| (c != 0).then_some(c))
.is_some()
{
if !output.status.success() {
output.stdout.shrink_to_fit();
output.stderr.shrink_to_fit();
maybe_keep_temps_then_clean();
Expand Down Expand Up @@ -657,6 +654,7 @@ where
trace!(
"transformed nvcc command: {:?}",
[
&[format!("cd {} &&", dir.to_string_lossy()).to_string()],
&[exe.to_str().unwrap_or_default().to_string()][..],
&args[..]
]
Expand Down Expand Up @@ -1051,14 +1049,14 @@ where
.await
}
}
.map_or_else(error_to_output, compile_result_to_output),
.map_or_else(error_to_output, |res| compile_result_to_output(exe, res)),
}
}
};

output = aggregate_output(output, out);

if output.status.code().unwrap_or(0) != 0 {
if !output.status.success() {
break;
}
}
Expand All @@ -1069,8 +1067,8 @@ where
fn aggregate_output(lhs: process::Output, rhs: process::Output) -> process::Output {
process::Output {
status: exit_status(std::cmp::max(
lhs.status.code().unwrap_or(0),
rhs.status.code().unwrap_or(0),
status_to_code(lhs.status),
status_to_code(rhs.status),
) as ExitStatusValue),
stdout: [lhs.stdout, rhs.stdout].concat(),
stderr: [lhs.stderr, rhs.stderr].concat(),
Expand All @@ -1088,14 +1086,40 @@ fn error_to_output(err: Error) -> process::Output {
}
}

fn compile_result_to_output(res: protocol::CompileFinished) -> process::Output {
fn compile_result_to_output(exe: &Path, res: protocol::CompileFinished) -> process::Output {
if let Some(signal) = res.signal {
return process::Output {
status: exit_status(signal as ExitStatusValue),
stdout: res.stdout,
stderr: [
format!(
"{} terminated (signal: {})",
exe.file_stem().unwrap().to_string_lossy(),
signal
)
.as_bytes(),
&res.stderr,
]
.concat(),
};
}
process::Output {
status: exit_status(res.retcode.or(res.signal).unwrap_or(0) as ExitStatusValue),
status: exit_status(res.retcode.unwrap_or(0) as ExitStatusValue),
stdout: res.stdout,
stderr: res.stderr,
}
}

fn status_to_code(res: process::ExitStatus) -> ExitStatusValue {
if res.success() {
0
} else if !cfg!(unix) {
res.code().unwrap_or(1)
} else {
res.signal().or(res.code()).unwrap_or(1)
}
}

counted_array!(pub static ARGS: [ArgInfo<gcc::ArgData>; _] = [
//todo: refactor show_includes into dependency_args
take_arg!("--Werror", OsString, CanBeSeparated('='), PreprocessorArgument),
Expand Down Expand Up @@ -1130,6 +1154,7 @@ counted_array!(pub static ARGS: [ArgInfo<gcc::ArgData>; _] = [
flag!("--save-temps", UnhashedFlag),
take_arg!("--system-include", PathBuf, CanBeSeparated('='), PreprocessorArgumentPath),
take_arg!("--threads", OsString, CanBeSeparated('='), Unhashed),
take_arg!("--x", OsString, CanBeSeparated('='), Language),

take_arg!("-Werror", OsString, CanBeSeparated('='), PreprocessorArgument),
take_arg!("-Xarchive", OsString, CanBeSeparated('='), PassThrough),
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/ptxas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl CCompilerImpl for Ptxas {
use cicc::ArgData::*;

counted_array!(pub static ARGS: [ArgInfo<cicc::ArgData>; _] = [
take_arg!("-arch", OsString, CanBeSeparated, PassThrough),
take_arg!("-m", OsString, CanBeSeparated, PassThrough),
take_arg!("-arch", OsString, CanBeSeparated('='), PassThrough),
take_arg!("-m", OsString, CanBeSeparated('='), PassThrough),
take_arg!("-o", PathBuf, Separated, Output),
]);

0 comments on commit 78f9ae0

Please sign in to comment.