Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions client/transport/ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ static int priskv_ucx_handshake(priskv_transport_conn *conn, ucp_address_t **add
uint32_t *address_len)
{
int ret;
ucs_status_t status;
uint8_t *peer_worker_address = NULL;

size_t hs_size = sizeof(priskv_cm_ucx_handshake) + conn->worker->address_len;
Expand Down Expand Up @@ -445,19 +446,21 @@ static int priskv_ucx_handshake(priskv_transport_conn *conn, ucp_address_t **add
hs->cap.max_inflight_command = htobe16(conn->param.max_inflight_command);
hs->address_len = htobe32(conn->worker->address_len);
memcpy(hs->address, conn->worker->address, conn->worker->address_len);
ret = priskv_safe_send(conn->connfd, hs, hs_size, NULL, NULL);
status = ucs_socket_send(conn->connfd, hs, hs_size);
free(hs);
if (ret) {
priskv_log_error("UCX: failed to send capability to server\n");
if (status != UCS_OK) {
priskv_log_error("UCX: failed to send capability to server, status: %s\n",
ucs_status_string(status));
ret = -1;
goto error;
}

/* receive response from server */
priskv_cm_ucx_handshake peer_hs;
ret = priskv_safe_recv(conn->connfd, &peer_hs, sizeof(peer_hs), NULL, NULL);
if (ret) {
priskv_log_error("UCX: failed to receive handshake msg from server\n");
status = ucs_socket_recv(conn->connfd, &peer_hs, sizeof(peer_hs));
if (status != UCS_OK) {
priskv_log_error("UCX: failed to receive handshake msg from server, status: %s\n",
ucs_status_string(status));
ret = -1;
goto error;
}
Expand Down Expand Up @@ -490,10 +493,10 @@ static int priskv_ucx_handshake(priskv_transport_conn *conn, ucp_address_t **add
ret = -1;
goto error;
}
ret = priskv_safe_recv(conn->connfd, peer_worker_address, peer_worker_address_len, NULL,
NULL);
if (ret) {
priskv_log_error("UCX: failed to receive peer_worker_address from server\n");
status = ucs_socket_recv(conn->connfd, peer_worker_address, peer_worker_address_len);
if (status != UCS_OK) {
priskv_log_error("UCX: failed to receive peer_worker_address from server, status: %s\n",
ucs_status_string(status));
ret = -1;
goto error;
}
Expand Down
49 changes: 0 additions & 49 deletions include/priskv-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,55 +148,6 @@ static inline void priskv_inet_ntop(struct sockaddr *addr, char *dst)
}
}

/**
 * Transfer exactly @size bytes over @sock, polling in 1ms slices.
 *
 * Each iteration waits up to 1ms for @poll_events on @sock, then invokes
 * @sock_call on the still-outstanding part of @data. After every iteration
 * (whether or not bytes moved) the optional @progress callback is invoked so
 * the caller can make forward progress on its own context while waiting.
 *
 * @param sock        socket descriptor to operate on
 * @param sock_call   recv()-shaped I/O routine: (fd, buffer, length, flags)
 * @param poll_events events passed to poll() (POLLIN or POLLOUT)
 * @param data        buffer to fill (POLLIN) or drain (POLLOUT)
 * @param size        number of bytes that must be transferred
 * @param progress    optional callback run once per loop iteration, may be NULL
 * @param arg         opaque argument handed to @progress
 * @param name        label of the operation; currently unused
 *
 * @return 0 once all @size bytes were transferred (trivially for size == 0),
 *         -1 on an unrecoverable poll()/I/O error, on end-of-stream while
 *         polling for POLLIN, or on an invalid @sock / @sock_call.
 */
static inline int priskv_sock_io(int sock, ssize_t (*sock_call)(int, void *, size_t, int),
                                 int poll_events, void *data, size_t size,
                                 void (*progress)(void *arg), void *arg, const char *name)
{
    size_t total = 0;

    (void)name; /* kept only for interface compatibility */

    if (size == 0) {
        return 0;
    }

    /*
     * poll() ignores entries whose fd is negative and simply times out, so
     * without this guard the loop below would never terminate.
     */
    if (sock < 0) {
        errno = EBADF;
        return -1;
    }
    if (sock_call == NULL) {
        errno = EINVAL;
        return -1;
    }

    while (total < size) {
        struct pollfd pfd = {.fd = sock, .events = (short)poll_events, .revents = 0};
        int ready = poll(&pfd, 1, 1); /* poll for 1ms */

        if (ready > 0) {
            /* keep the full ssize_t width; do not narrow the byte count to int */
            ssize_t nbytes = sock_call(sock, (char *)data + total, size - total, 0);

            if (nbytes > 0) {
                total += (size_t)nbytes;
            } else if (nbytes == 0) {
                /* a zero-byte read means the peer closed the connection */
                if (poll_events & POLLIN) {
                    return -1;
                }
            } else if (errno != EINTR && errno != EAGAIN && errno != EWOULDBLOCK) {
                /* transient conditions are retried, like EINTR from poll() */
                return -1;
            }
        } else if ((ready < 0) && (errno != EINTR)) {
            return -1;
        }

        /* progress user context */
        if (progress != NULL) {
            progress(arg);
        }
    }
    return 0;
}

/*
 * Adapter that gives send() the non-const buffer signature expected by
 * priskv_sock_io(). Calling send() through a pointer cast to an incompatible
 * function type is undefined behaviour, so it is wrapped in a correctly typed
 * function instead.
 */
static inline ssize_t priskv_sock_send_adapter(int sock, void *data, size_t size, int flags)
{
    return send(sock, data, size, flags);
}

/**
 * Send exactly @size bytes from @data over @sock.
 *
 * Delegates to priskv_sock_io() with POLLOUT; @progress (may be NULL) is
 * forwarded together with @arg.
 *
 * @return the result of priskv_sock_io(): 0 on success, -1 on failure.
 */
static inline int priskv_safe_send(int sock, void *data, size_t size, void (*progress)(void *arg),
                                   void *arg)
{
    return priskv_sock_io(sock, priskv_sock_send_adapter, POLLOUT, data, size, progress, arg,
                          "send");
}

/**
 * Receive exactly @size bytes into @data from @sock.
 *
 * Hands recv() and POLLIN to priskv_sock_io(), forwarding the optional
 * @progress callback and its @arg unchanged.
 *
 * @return whatever priskv_sock_io() reports for the transfer.
 */
static inline int priskv_safe_recv(int sock, void *data, size_t size, void (*progress)(void *arg),
                                   void *arg)
{
    const int rc = priskv_sock_io(sock, recv, POLLIN, data, size, progress, arg, "recv");

    return rc;
}

static inline unsigned long priskv_rdtsc(void)
{
unsigned long low, high;
Expand Down
40 changes: 24 additions & 16 deletions server/transport/ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,16 @@ static inline void priskv_ucx_reject(priskv_transport_conn *client, priskv_cm_st
.status = htobe16(status),
.value = htobe64(value),
};
int ret = priskv_safe_send(client->connfd, &rej_msg_be, sizeof(rej_msg_be), NULL, NULL);
if (ret < 0) {
priskv_log_error("UCX: send reject message failed: %m\n");
ucs_status_t ucs_status = ucs_socket_send(client->connfd, &rej_msg_be, sizeof(rej_msg_be));
if (ucs_status != UCS_OK) {
priskv_log_error("UCX: send reject message failed, status: %s\n",
ucs_status_string(ucs_status));
}
}

static inline int priskv_ucx_accept(priskv_transport_conn *client)
{
int ret = 0;
uint32_t address_len = client->worker->address_len;
size_t hs_size = sizeof(priskv_cm_ucx_handshake) + address_len;
priskv_cm_ucx_handshake *hs = malloc(hs_size);
Expand Down Expand Up @@ -273,9 +275,11 @@ static inline int priskv_ucx_accept(priskv_transport_conn *client)
client->peer_addr, address_len, print_len, worker_address_hex);
}

int ret = priskv_safe_send(client->connfd, hs, hs_size, NULL, NULL);
if (ret < 0) {
priskv_log_error("UCX: send accept message failed: %m\n");
ucs_status_t status = ucs_socket_send(client->connfd, hs, hs_size);
if (status != UCS_OK) {
ret = -1;
priskv_log_error("UCX: send accept message failed, status: %s\n",
ucs_status_string(status));
goto out_free_msg;
}

Expand All @@ -290,6 +294,7 @@ static inline int priskv_ucx_accept(priskv_transport_conn *client)
static inline int priskv_ucx_handle_handshake(void *arg)
{
int ret;
ucs_status_t sock_status;
priskv_cm_ucx_handshake peer_hs;
priskv_cm_status status;

Expand All @@ -299,11 +304,12 @@ static inline int priskv_ucx_handle_handshake(void *arg)
int connfd = client->connfd;

/* #step0, recv handshake msg */
ret = priskv_safe_recv(connfd, &peer_hs, sizeof(peer_hs), NULL, NULL);
if (ret < 0) {
priskv_log_error("UCX: recv handshake msg failed: %m\n");
sock_status = ucs_socket_recv(connfd, &peer_hs, sizeof(peer_hs));
if (sock_status != UCS_OK) {
priskv_log_error("UCX: recv handshake msg failed, status: %s\n",
ucs_status_string(sock_status));
ucs_close_fd(&connfd);
return;
return -1;
}

client->conn_cap.version = be16toh(peer_hs.cap.version);
Expand All @@ -317,13 +323,14 @@ static inline int priskv_ucx_handle_handshake(void *arg)
if (!peer_worker_address) {
priskv_log_error("UCX: malloc peer address failed: %m\n");
ucs_close_fd(&connfd);
return;
return -1;
}
ret = priskv_safe_recv(connfd, peer_worker_address, peer_worker_address_len, NULL, NULL);
if (ret < 0) {
priskv_log_error("UCX: recv peer address failed: %m\n");
sock_status = ucs_socket_recv(connfd, peer_worker_address, peer_worker_address_len);
if (sock_status != UCS_OK) {
priskv_log_error("UCX: recv peer address failed, status: %s\n",
ucs_status_string(sock_status));
ucs_close_fd(&connfd);
return;
return -1;
}
}

Expand Down Expand Up @@ -437,7 +444,7 @@ static inline int priskv_ucx_handle_handshake(void *arg)

priskv_log_notice("UCX: <%s - %s> established\n", client->local_addr, client->peer_addr);

return;
return 0;

rej:
priskv_log_warn("UCX: <%s - %s> %s, reject\n", client->local_addr, client->peer_addr,
Expand All @@ -448,6 +455,7 @@ static inline int priskv_ucx_handle_handshake(void *arg)
}
priskv_ucx_reject(client, status, value);
priskv_ucx_mark_client_closed(client);
return -1;
}

static inline void priskv_ucx_handle_cm(int fd, void *opaque, uint32_t ev)
Expand Down
Loading