diff --git a/c2rust-transpile/src/translator/functions.rs b/c2rust-transpile/src/translator/functions.rs index a733550005..7071bff675 100644 --- a/c2rust-transpile/src/translator/functions.rs +++ b/c2rust-transpile/src/translator/functions.rs @@ -401,7 +401,7 @@ impl<'c> Translation<'c> { // Function pointer call _ => { - let mut callee = self.convert_expr(ctx.used(), func, None)?; + let callee = self.convert_expr(ctx.used(), func, None)?; let make_fn_ty = |ret_ty: Box| { let ret_ty = match *ret_ty { Type::Tuple(TypeTuple { elems: ref v, .. }) if v.is_empty() => ReturnType::Default, @@ -419,20 +419,18 @@ impl<'c> Translation<'c> { // K&R function pointer without arguments let ret_ty = self.convert_type(ret_ty.ctype)?; let target_ty = make_fn_ty(ret_ty); - callee.set_unsafe(); callee.map(|fn_ptr| { let fn_ptr = unwrap_function_pointer(fn_ptr); transmute_expr(mk().infer_ty(), target_ty, fn_ptr) - }) + }).set_unsafe() } None => { // We have to infer the return type from our expression type let ret_ty = self.convert_type(call_expr_ty.ctype)?; let target_ty = make_fn_ty(ret_ty); - callee.set_unsafe(); callee.map(|fn_ptr| { transmute_expr(mk().infer_ty(), target_ty, fn_ptr) - }) + }).set_unsafe() } Some(CTypeKind::Function(_, ty_arg_tys, ..)) => { arg_tys = Some(ty_arg_tys.clone()); diff --git a/c2rust-transpile/src/translator/literals.rs b/c2rust-transpile/src/translator/literals.rs index 5bedaffa81..e1bfeca9f3 100644 --- a/c2rust-transpile/src/translator/literals.rs +++ b/c2rust-transpile/src/translator/literals.rs @@ -140,7 +140,7 @@ impl<'c> Translation<'c> { val = mk().const_block_expr(mk().const_block(stmts)); } - Ok(WithStmts::new_unsafe_val(val)) + Ok(WithStmts::new_val(val).set_unsafe()) } } } @@ -183,17 +183,13 @@ impl<'c> Translation<'c> { Ok(val.wrap_unsafe().and_then(|val| { let item = mk().mutbl().static_item(&fresh_name, fresh_ty, val); let fresh_stmt = mk().item_stmt(item); - let mut val = WithStmts::new(vec![fresh_stmt], mk().ident_expr(fresh_name)); - // Accessing a static variable is unsafe. - // In the current nightly, this applies also to taking a raw pointer, - // but this requirement was removed in later versions of the - // `raw_ref_op` feature. - if self.tcfg.edition < Edition2024 { - val.set_unsafe(); - } - - val + WithStmts::new(vec![fresh_stmt], mk().ident_expr(fresh_name)) + // Accessing a static variable is unsafe. + // In the current nightly, this applies also to taking a raw pointer, + // but this requirement was removed in later versions of the + // `raw_ref_op` feature. + .merge_unsafe(self.tcfg.edition < Edition2024) })) } else { Ok(val.and_then(|val| { diff --git a/c2rust-transpile/src/translator/mod.rs b/c2rust-transpile/src/translator/mod.rs index 8d180d4926..29f30b187b 100644 --- a/c2rust-transpile/src/translator/mod.rs +++ b/c2rust-transpile/src/translator/mod.rs @@ -2364,9 +2364,7 @@ impl<'c> Translation<'c> { .span(span) .mutbl() .static_item(&ident2, ty, default_init); - let mut init = init?; - init.set_unsafe(); - let mut init = init.to_expr(); + let mut init = init?.set_unsafe().to_expr(); self.add_static_initializer_to_section(ctx, &ident2, typ, &mut init)?; self.items.borrow_mut()[&self.main_file].add_item(static_item); @@ -3134,9 +3132,7 @@ impl<'c> Translation<'c> { } } - let mut res = WithStmts::new_val(val); - res.merge_unsafe(set_unsafe); - Ok(res) + Ok(WithStmts::new_val(val).merge_unsafe(set_unsafe)) } OffsetOf(ty, ref kind) => match kind { @@ -3258,7 +3254,7 @@ impl<'c> Translation<'c> { if is_explicit { let stmts = self.compute_variable_array_sizes(ctx, ty.ctype)?; - val.prepend_stmts(stmts); + val = val.prepend_stmts(stmts); } // Shuffle Vector "function" builtins will add a cast to the output of the @@ -3325,14 +3321,17 @@ impl<'c> Translation<'c> { let is_unsafe = lhs.is_unsafe() || rhs.is_unsafe(); let then = mk().block(lhs.into_stmts()); let else_ = mk().block_expr(mk().block(rhs.into_stmts())); + let res = cond + .and_then(|c| { + WithStmts::new( + vec![mk().semi_stmt(mk().ifte_expr(c, then, Some(else_)))], + self.panic_or_err( + "Conditional expression is not supposed to be used", + ), + ) + }) + .merge_unsafe(is_unsafe); - let mut res = cond.and_then(|c| { - WithStmts::new( - vec![mk().semi_stmt(mk().ifte_expr(c, then, Some(else_)))], - self.panic_or_err("Conditional expression is not supposed to be used"), - ) - }); - res.merge_unsafe(is_unsafe); Ok(res) } else { let then = lhs.to_block(); @@ -3354,7 +3353,7 @@ impl<'c> Translation<'c> { if ctx.is_unused() { let mut lhs = self.convert_condition(ctx, false, lhs)?; let rhs = self.convert_expr(ctx, rhs, None)?; - lhs.merge_unsafe(rhs.is_unsafe()); + lhs = lhs.merge_unsafe(rhs.is_unsafe()); Ok(lhs.and_then(|val| { WithStmts::new( diff --git a/c2rust-transpile/src/translator/operators.rs b/c2rust-transpile/src/translator/operators.rs index 9ab837aaec..d5047c58dd 100644 --- a/c2rust-transpile/src/translator/operators.rs +++ b/c2rust-transpile/src/translator/operators.rs @@ -422,11 +422,11 @@ impl<'c> Translation<'c> { let assign_stmt = match op { // Regular (possibly volatile) assignment Assign if !is_volatile => WithStmts::new_val(mk().assign_expr(write, rhs)), - Assign => WithStmts::new_unsafe_val(self.volatile_write( + Assign => WithStmts::new_val(self.volatile_write( write, initial_lhs_type_id, rhs, - )?), + )?).set_unsafe(), // Anything volatile needs to be desugared into explicit reads and writes op if is_volatile || is_unsigned_arith => { @@ -473,9 +473,9 @@ impl<'c> Translation<'c> { #[allow(clippy::let_and_return /* , reason = "block is large, so variable name helps" */)] let write = if is_volatile { val.and_then_try(|val| { - TranslationResult::Ok(WithStmts::new_unsafe_val( + TranslationResult::Ok(WithStmts::new_val( self.volatile_write(write, initial_lhs_type_id, val)?, - )) + ).set_unsafe()) })? } else { val.map(|val| mk().assign_expr(write, val)) @@ -617,7 +617,7 @@ impl<'c> Translation<'c> { offset = mk().binary_expr(BinOp::Div(Default::default()), offset, div); } - Ok(WithStmts::new_unsafe_val(mk().cast_expr(offset, ty))) + Ok(WithStmts::new_val(mk().cast_expr(offset, ty)).set_unsafe()) } else if let &CTypeKind::Pointer(pointee) = lhs_type { Ok(self.convert_pointer_offset(lhs, rhs, pointee.ctype, true, false)) } else if lhs_type.is_unsigned_integral_type() { @@ -772,13 +772,12 @@ impl<'c> Translation<'c> { mk().assign_expr(write, val) }; - let mut val = WithStmts::new( + let val = WithStmts::new( vec![save_old_val, mk().expr_stmt(assign_stmt)], mk().ident_expr(val_name), - ); - if is_unsafe { - val.set_unsafe(); - } + ) + .merge_unsafe(is_unsafe); + Ok(val) }, ) diff --git a/c2rust-transpile/src/translator/pointers.rs b/c2rust-transpile/src/translator/pointers.rs index 6ae8b1ce35..a15fef518c 100644 --- a/c2rust-transpile/src/translator/pointers.rs +++ b/c2rust-transpile/src/translator/pointers.rs @@ -382,7 +382,7 @@ impl<'c> Translation<'c> { res = mk().unary_expr(UnOp::Deref(Default::default()), res); } - WithStmts::new_unsafe_val(res) + WithStmts::new_val(res).set_unsafe() } /// Construct an expression for a NULL at any type, including forward declarations, @@ -450,7 +450,7 @@ impl<'c> Translation<'c> { self.import_type(target_cty); Ok(val.and_then(|val| { - WithStmts::new_unsafe_val(transmute_expr(source_ty, target_ty, val)) + WithStmts::new_val(transmute_expr(source_ty, target_ty, val)).set_unsafe() })) } else { // Normal case @@ -483,7 +483,7 @@ impl<'c> Translation<'c> { let intptr_t = mk().abs_path_ty(vec!["libc", "intptr_t"]); val = mk().cast_expr(val, intptr_t.clone()); - WithStmts::new_unsafe_val(transmute_expr(intptr_t, target_ty, val)) + WithStmts::new_val(transmute_expr(intptr_t, target_ty, val)).set_unsafe() })) } else if source_ty_kind.is_bool() { self.use_crate(ExternCrate::Libc); @@ -520,7 +520,7 @@ impl<'c> Translation<'c> { if self.ast_context.is_function_pointer(source_cty) { Ok(val.and_then(|val| { - WithStmts::new_unsafe_val(transmute_expr(source_ty, target_ty, val)) + WithStmts::new_val(transmute_expr(source_ty, target_ty, val)).set_unsafe() })) } else if let &CTypeKind::Enum(enum_decl_id) = target_ty_kind { val.try_map(|val| self.convert_cast_to_enum(ctx, target_cty, enum_decl_id, expr, val)) diff --git a/c2rust-transpile/src/translator/simd.rs b/c2rust-transpile/src/translator/simd.rs index 91fee866a1..cca06186e9 100644 --- a/c2rust-transpile/src/translator/simd.rs +++ b/c2rust-transpile/src/translator/simd.rs @@ -270,11 +270,10 @@ impl<'c> Translation<'c> { let n_bytes_expr = mk().lit_expr(mk().int_lit(bytes, "")); let expr = mk().repeat_expr(zero_expr, n_bytes_expr); - Ok(WithStmts::new_unsafe_val(transmute_expr( - mk().infer_ty(), - mk().infer_ty(), - expr, - ))) + Ok( + WithStmts::new_val(transmute_expr(mk().infer_ty(), mk().infer_ty(), expr)) + .set_unsafe(), + ) } else { self.import_simd_function(fn_name) .expect("None of these fns should be unsupported in rust"); @@ -334,7 +333,7 @@ impl<'c> Translation<'c> { mk().call_expr(mk().ident_expr(fn_call_name), params) }; - let mut val = if ctx.is_used() { + let val = if ctx.is_used() { WithStmts::new_val(call) } else { WithStmts::new( @@ -342,9 +341,8 @@ impl<'c> Translation<'c> { self.panic_or_err("No value for unused shuffle vector return"), ) }; - val.merge_unsafe(is_unsafe); - Ok(val) + Ok(val.merge_unsafe(is_unsafe)) }) } diff --git a/c2rust-transpile/src/translator/structs_unions.rs b/c2rust-transpile/src/translator/structs_unions.rs index 5eafa93305..db21177a4d 100644 --- a/c2rust-transpile/src/translator/structs_unions.rs +++ b/c2rust-transpile/src/translator/structs_unions.rs @@ -541,12 +541,7 @@ impl<'a> Translation<'a> { stmts.push(mk().expr_stmt(mk().ident_expr("init"))); let val = mk().block_expr(mk().block(stmts)); - - if is_unsafe { - WithStmts::new_unsafe_val(val) - } else { - WithStmts::new_val(val) - } + WithStmts::new_val(val).merge_unsafe(is_unsafe) }); } diff --git a/c2rust-transpile/src/with_stmts.rs b/c2rust-transpile/src/with_stmts.rs index cc4e588dec..1eb5221f04 100644 --- a/c2rust-transpile/src/with_stmts.rs +++ b/c2rust-transpile/src/with_stmts.rs @@ -27,14 +27,6 @@ impl WithStmts { } } - pub fn new_unsafe_val(val: T) -> Self { - WithStmts { - stmts: vec![], - val, - is_unsafe: true, - } - } - pub fn and_then(self, f: F) -> WithStmts where F: FnOnce(T) -> WithStmts, @@ -95,12 +87,14 @@ impl WithStmts { } } - pub fn set_unsafe(&mut self) { + pub fn set_unsafe(mut self) -> Self { self.is_unsafe = true; + self } - pub fn merge_unsafe(&mut self, is_unsafe: bool) { + pub fn merge_unsafe(mut self, is_unsafe: bool) -> Self { self.is_unsafe = self.is_unsafe || is_unsafe; + self } pub fn into_stmts(self) -> Vec { @@ -142,13 +136,15 @@ impl WithStmts { self.is_unsafe } - pub fn add_stmt(&mut self, stmt: Stmt) { + pub fn add_stmt(mut self, stmt: Stmt) -> Self { self.stmts.push(stmt); + self } - pub fn prepend_stmts(&mut self, mut stmts: Vec) { + pub fn prepend_stmts(mut self, mut stmts: Vec) -> Self { stmts.append(&mut self.stmts); self.stmts = stmts; + self } pub fn is_pure(&self) -> bool { @@ -211,8 +207,6 @@ impl FromIterator> for WithStmts> { stmts.append(val.stmts_mut()); res.push(val.into_value()); } - let mut translation = WithStmts::new(stmts, res); - translation.merge_unsafe(is_unsafe); - translation + WithStmts::new(stmts, res).merge_unsafe(is_unsafe) } }