Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions executor/common_linux.h
Original file line number Diff line number Diff line change
Expand Up @@ -4415,7 +4415,6 @@ inline int symlink(const char* old_path, const char* new_path)
#define SYSTEM_UID 1000
#define SYSTEM_GID 1000

const char* const SELINUX_CONTEXT_UNTRUSTED_APP = "u:r:untrusted_app:s0:c512,c768";
const char* const SELINUX_LABEL_APP_DATA_FILE = "u:object_r:app_data_file:s0:c512,c768";
const char* const SELINUX_CONTEXT_FILE = "/proc/thread-self/attr/current";
const char* const SELINUX_XATTR_NAME = "security.selinux";
Expand Down Expand Up @@ -4455,7 +4454,7 @@ static void getcon(char* context, size_t context_size)
// - Uses fail() instead of returning an error code
static void setcon(const char* context)
{
char new_context[512];
char new_context[512] = {0};

// Attempt to write the new context
int fd = open(SELINUX_CONTEXT_FILE, O_WRONLY);
Expand All @@ -4470,7 +4469,7 @@ static void setcon(const char* context)
close(fd);

if (bytes_written != (ssize_t)strlen(context))
failmsg("setcon: could not write entire context", "wrote=%zi, expected=%zu", bytes_written, strlen(context));
failmsg("setcon: could not write entire context", "context: %s, wrote=%zi, expected=%zu", context, bytes_written, strlen(context));

// Validate the transition by checking the context
getcon(new_context, sizeof(new_context));
Expand Down Expand Up @@ -4567,8 +4566,6 @@ static int do_sandbox_android(uint64 sandbox_arg)
prctl(PR_SET_PDEATHSIG, SIGKILL, 0, 0, 0);

setfilecon(".", SELINUX_LABEL_APP_DATA_FILE);
if (uid == UNTRUSTED_APP_UID)
setcon(SELINUX_CONTEXT_UNTRUSTED_APP);

loop();
doexit(1);
Expand Down
18 changes: 17 additions & 1 deletion executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ const uint64 instr_eof = -1;
const uint64 instr_copyin = -2;
const uint64 instr_copyout = -3;
const uint64 instr_setprops = -4;
const uint64 instr_seccontext = -5;

const uint64 arg_const = 0;
const uint64 arg_addr32 = 1;
Expand Down Expand Up @@ -970,10 +971,25 @@ void execute_one()
memset(&call_props, 0, sizeof(call_props));

read_input(&input_pos); // total number of calls
for (;;) {
for (int index = 0;; index++) {
uint64 call_num = read_input(&input_pos);
if (call_num == instr_eof)
break;
#if GOOS_linux
if (call_num == instr_seccontext) {
if (index) {
fail("seclabel instruction is not the first call\n");
}
uint64 size = read_input(&input_pos);
char seclabel[64]{};
memcpy(seclabel, input_pos, size);
setcon(seclabel);
input_pos += size;

debug_verbose("applied security label: %s\n", seclabel);
continue;
}
#endif
if (call_num == instr_copyin) {
char* addr = (char*)(read_input(&input_pos) + SYZ_DATA_OFFSET);
uint64 typ = read_input(&input_pos);
Expand Down
219 changes: 216 additions & 3 deletions executor/executor_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <signal.h>
#include <sys/mman.h>
#include <sys/resource.h>
#include <sys/socket.h>
#include <unistd.h>

#include <algorithm>
Expand All @@ -17,6 +18,27 @@
#include <utility>
#include <vector>

#include <linux/audit.h>
#include <linux/netlink.h>

constexpr int NETLINK_BUF_SIZE = 4096;

// Helper function to open a Netlink socket for Audit
int OpenNetlinkAuditSocket()
{
return socket(AF_NETLINK, SOCK_RAW, NETLINK_AUDIT);
}

ssize_t ReceiveNetlinkMessage(int fd, void* buf, size_t len)
{
return recv(fd, buf, len, 0);
}

ssize_t SendNetlinkMessage(int fd, const void* buf, size_t len)
{
return send(fd, buf, len, 0);
}

inline std::ostream& operator<<(std::ostream& ss, const rpc::ExecRequestRawT& req)
{
return ss << "id=" << req.id
Expand Down Expand Up @@ -109,7 +131,7 @@ class Proc
{
public:
Proc(Connection& conn, const char* bin, ProcIDPool& proc_id_pool, int& restarting, const bool& corpus_triaged, int max_signal_fd,
int cover_filter_fd, ProcOpts opts)
int cover_filter_fd, ProcOpts opts, int audit_sock)
: conn_(conn),
bin_(bin),
proc_id_pool_(proc_id_pool),
Expand All @@ -121,7 +143,8 @@ class Proc
opts_(opts),
req_shmem_(kMaxInput),
resp_shmem_(kMaxOutput),
resp_mem_(static_cast<OutputData*>(resp_shmem_.Mem()))
resp_mem_(static_cast<OutputData*>(resp_shmem_.Mem())),
audit_sock_(audit_sock)
{
Start();
}
Expand Down Expand Up @@ -237,6 +260,7 @@ class Proc
uint64 exec_start_ = 0;
uint64 wait_start_ = 0;
uint64 wait_end_ = 0;
int audit_sock_ = 0;

friend std::ostream& operator<<(std::ostream& ss, const Proc& proc)
{
Expand All @@ -251,6 +275,112 @@ class Proc
return ss;
}

ssize_t SendUserAuditMessage(const std::string &message_text) {
const size_t payload_len = message_text.length() + 1;
const size_t buf_len = NLMSG_SPACE(payload_len);
std::vector<char> buf(buf_len);
memset(buf.data(), 0, buf_len);

int fd = OpenNetlinkAuditSocket();
if (fd < 0) {
debug("Failed to open socket to send audit message.\n");
return -1;
}

auto* nlh = reinterpret_cast<struct nlmsghdr*>(buf.data());
nlh->nlmsg_len = NLMSG_LENGTH((int)payload_len);
nlh->nlmsg_type = AUDIT_USER_AVC;
nlh->nlmsg_flags = NLM_F_REQUEST;

char* data = static_cast<char*>(NLMSG_DATA(nlh));
strncpy(data, message_text.c_str(), payload_len);
ssize_t res = SendNetlinkMessage(fd, nlh, nlh->nlmsg_len);
close(fd);
return res;
}

void CollectAuditLogs(std::vector<uint8_t>* output, int64_t req_id)
{
bool prefixed = false;
ssize_t slen = 0;
char buf[NETLINK_BUF_SIZE]{};
struct nlmsghdr* header;
std::string message;

if (SendUserAuditMessage("PROC_END") < 0) {
debug("Failed to send PROC_END. Stopping drain.");
return;
}

while (true) {
slen = ReceiveNetlinkMessage(audit_sock_, buf, sizeof(buf));
if (errno == EINTR) {
continue;
}
if (slen < 0) {
fprintf(stderr, "audit: receive error: %s\n", strerror(errno));
continue;
}
if (slen < NLMSG_LENGTH(0)) {
fprintf(stderr, "audit: message too short\n");
continue;
}
header = (struct nlmsghdr*)buf;
message = std::string((char*)NLMSG_DATA(header),
(char*)NLMSG_DATA(header) +
(slen - sizeof(*header)));
debug("proc %d: req(%ld) - Audit message: %s\n", id_, msg_->id, message.c_str());
if (header->nlmsg_type != AUDIT_USER_AVC) {
continue;
}
if (strstr(message.c_str(), "PROC_START")) {
break;
}
}

// Drain the audit backlog until there is no other message
while (true) {
slen = ReceiveNetlinkMessage(audit_sock_, buf, sizeof(buf));
if (errno == EINTR) {
continue;
}
if (slen < 0) {
fprintf(stderr, "audit: receive error: %s\n", strerror(errno));
continue;
}
if (slen < NLMSG_LENGTH(0)) {
fprintf(stderr, "audit: message too short\n");
continue;
}
header = (struct nlmsghdr*)buf;
message = std::string((char*)NLMSG_DATA(header),
(char*)NLMSG_DATA(header) +
(slen - sizeof(*header)));
debug("proc %d: req(%ld) - Audit message: %s\n", id_, msg_->id, message.c_str());
if (header->nlmsg_type != AUDIT_AVC && header->nlmsg_type != AUDIT_USER_AVC) {
continue;
}
if (header->nlmsg_type == AUDIT_USER_AVC && strstr(message.c_str(), "PROC_END")) {
break;
}
if (header->nlmsg_type == AUDIT_AVC) {
const char *found = strstr(message.c_str(), "syz");
if (!found || !strstr(found, std::to_string(msg_->id).c_str())) {
continue;
}
if (!prefixed) {
char tmp[128];
// Add prefix to the audit messages.
snprintf(tmp, sizeof(tmp), "\nAudit messages:\n");
output->insert(output->end(), tmp, tmp + strlen(tmp));
prefixed = true;
}
message.append("\n");
output->insert(output->end(), message.c_str(), message.c_str() + strlen(message.c_str()));
}
}
}

void ChangeState(State state)
{
if (state_ == State::Handshaking)
Expand Down Expand Up @@ -392,6 +522,12 @@ class Proc

debug("proc %d: start executing request %llu\n", id_, static_cast<uint64>(msg_->id));

if (IsSet(msg_->flags, rpc::RequestFlag::ReturnAudit)) {
if (SendUserAuditMessage("PROC_START") < 0) {
debug("Failed to send PROC START\n");
}
}

rpc::ExecutingMessageRawT exec;
exec.id = msg_->id;
exec.proc_id = id_;
Expand Down Expand Up @@ -454,6 +590,10 @@ class Proc
output_.insert(output_.end(), tmp, tmp + strlen(tmp));
}
}
if (IsSet(msg_->flags, rpc::RequestFlag::ReturnAudit)) {
output = &output_;
CollectAuditLogs(output, msg_->id);
}
uint32 num_calls = 0;
if (msg_->type == rpc::RequestType::Program)
num_calls = read_input(&prog_data);
Expand Down Expand Up @@ -556,9 +696,16 @@ class Runner
proc_id_pool_.emplace(num_procs);
int max_signal_fd = max_signal_ ? max_signal_->FD() : -1;
int cover_filter_fd = cover_filter_ ? cover_filter_->FD() : -1;
int audit_sock = 0;
if (audit) {
audit_sock = registerForAudit();
if (audit_sock < 0) {
debug("Failed to register as the audit sink\n");
}
}
for (int i = 0; i < num_procs; i++)
procs_.emplace_back(new Proc(conn, bin, *proc_id_pool_, restarting_, corpus_triaged_,
max_signal_fd, cover_filter_fd, proc_opts_));
max_signal_fd, cover_filter_fd, proc_opts_, audit_sock));

for (;;)
Loop();
Expand All @@ -574,6 +721,7 @@ class Runner
std::deque<rpc::ExecRequestRawT> requests_;
std::vector<std::string> leak_frames_;
int restarting_ = 0;
bool audit = false;
bool corpus_triaged_ = false;
ProcOpts proc_opts_{};

Expand All @@ -595,6 +743,64 @@ class Runner
return ss;
}

int registerForAudit()
{
struct {
struct nlmsghdr nlh;
struct audit_status status;
} req;
memset(&req, 0, sizeof(req));

int fd = OpenNetlinkAuditSocket();
if (fd < 0) {
return -1;
}

req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct audit_status));
req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
req.nlh.nlmsg_seq = 1;
req.nlh.nlmsg_type = AUDIT_SET;
req.status.pid = getpid();
req.status.backlog_limit = 0;
req.status.rate_limit = 0;
req.status.mask = AUDIT_STATUS_PID | AUDIT_STATUS_BACKLOG_LIMIT | AUDIT_STATUS_RATE_LIMIT;

ssize_t sent = SendNetlinkMessage(fd, &req, req.nlh.nlmsg_len);
if (sent != req.nlh.nlmsg_len) {
close(fd);
return -1;
}

ssize_t slen = 0;
char buf[NETLINK_BUF_SIZE];
struct nlmsghdr* header;
do {
slen = ReceiveNetlinkMessage(fd, buf, sizeof(buf));
if (errno == EAGAIN || errno == EINTR) {
continue;
}
if (slen < NLMSG_LENGTH(0)) {
fprintf(stderr, "audit: message too short\n");
continue;
}
header = (struct nlmsghdr*)buf;
} while (header->nlmsg_type != NLMSG_ERROR);

struct nlmsgerr* err;
if ((size_t)slen < NLMSG_LENGTH(sizeof(*err))) {
fprintf(stderr, "audit_listener: error message too short\n");
close(fd);
return -1;
}
err = (struct nlmsgerr*)NLMSG_DATA(header);
if (err->error != 0) {
fprintf(stderr, "audit_listener: received error %d\n", -err->error);
close(fd);
return -1;
}
return fd;
}

void Loop()
{
Select select;
Expand Down Expand Up @@ -661,6 +867,13 @@ class Runner
conn_.Recv(conn_reply);
if (conn_reply.debug)
flag_debug = true;
if (conn_reply.audit) {
if (conn_reply.procs > 1) {
debug("extracting audit logs only supported with one proc");
}
audit = conn_reply.procs == 1;
}

debug("connected to manager: procs=%d cover_edges=%d kernel_64_bit=%d slowdown=%d syscall_timeout=%u"
" program_timeout=%u features=0x%llx\n",
conn_reply.procs, conn_reply.cover_edges, conn_reply.kernel_64_bit,
Expand Down
3 changes: 3 additions & 0 deletions pkg/flatrpc/flatrpc.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ table ConnectRequestRaw {

table ConnectReplyRaw {
debug :bool;
audit :bool;
cover :bool;
cover_edges :bool;
kernel_64_bit :bool;
Expand Down Expand Up @@ -128,6 +129,8 @@ enum RequestType : uint64 {
enum RequestFlag : uint64 (bit_flags) {
// If set, collect program output and return in output field.
ReturnOutput,
// If set, collect audit logs produced by the program at the end of the output.
ReturnAudit,
// If set, don't fail on program failures, instead return the error in error field.
ReturnError,
}
Expand Down
Loading