From 82b5d21a12c71f527ff4cd3a511fe7e5084106db Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Fri, 22 Jul 2022 05:47:04 +0200 Subject: [PATCH] Add callbacks for channel open response, and channel request response. Signed-off-by: Tom Deseyn --- include/libssh/callbacks.h | 31 +++ src/channels.c | 28 +++ tests/client/CMakeLists.txt | 1 + tests/client/torture_client_callbacks.c | 250 ++++++++++++++++++++++++ 4 files changed, 310 insertions(+) create mode 100644 tests/client/torture_client_callbacks.c diff --git a/include/libssh/callbacks.h b/include/libssh/callbacks.h index b97522f4f..a753b712c 100644 --- a/include/libssh/callbacks.h +++ b/include/libssh/callbacks.h @@ -27,6 +27,7 @@ #include #include +#include #ifdef __cplusplus extern "C" { @@ -780,6 +781,28 @@ typedef int (*ssh_channel_write_wontblock_callback) (ssh_session session, uint32_t bytes, void *userdata); +/** + * @brief SSH channel open callback. Called when a channel open succeeds or fails. + * @param session Current session handler + * @param channel the actual channel + * @param is_success is 1 when the open succeeds, and 0 otherwise. + * @param userdata Userdata to be passed to the callback function. + */ +typedef void (*ssh_channel_open_resp_callback) (ssh_session session, + ssh_channel channel, + bool is_success, + void *userdata); + +/** + * @brief SSH channel request response callback. Called when a response to the pending request is received. + * @param session Current session handler + * @param channel the actual channel + * @param userdata Userdata to be passed to the callback function. + */ +typedef void (*ssh_channel_request_resp_callback) (ssh_session session, + ssh_channel channel, + void *userdata); + struct ssh_channel_callbacks_struct { /** DON'T SET THIS use ssh_callbacks_init() instead. */ size_t size; @@ -847,6 +870,14 @@ struct ssh_channel_callbacks_struct { * not to block. 
*/ ssh_channel_write_wontblock_callback channel_write_wontblock_function; + /** + * This function will be called when the channel has received a channel open confirmation or failure. + */ + ssh_channel_open_resp_callback channel_open_response_function; + /** + * This function will be called when the channel has received the response to the pending request. + */ + ssh_channel_request_resp_callback channel_request_response_function; }; typedef struct ssh_channel_callbacks_struct *ssh_channel_callbacks; diff --git a/src/channels.c b/src/channels.c index e718458cd..268553415 100644 --- a/src/channels.c +++ b/src/channels.c @@ -212,6 +212,14 @@ SSH_PACKET_CALLBACK(ssh_packet_channel_open_conf){ channel->state = SSH_CHANNEL_STATE_OPEN; channel->flags &= ~SSH_CHANNEL_FLAG_NOT_BOUND; + + ssh_callbacks_execute_list(channel->callbacks, + ssh_channel_callbacks, + channel_open_response_function, + channel->session, + channel, + true /* is_success */); + return SSH_PACKET_USED; error: @@ -261,6 +269,14 @@ SSH_PACKET_CALLBACK(ssh_packet_channel_open_fail){ error); SAFE_FREE(error); channel->state=SSH_CHANNEL_STATE_OPEN_DENIED; + + ssh_callbacks_execute_list(channel->callbacks, + ssh_channel_callbacks, + channel_open_response_function, + channel->session, + channel, + false /* is_success */); + return SSH_PACKET_USED; error: @@ -1713,6 +1729,12 @@ SSH_PACKET_CALLBACK(ssh_packet_channel_success){ channel->request_state); } else { channel->request_state=SSH_CHANNEL_REQ_STATE_ACCEPTED; + + ssh_callbacks_execute_list(channel->callbacks, + ssh_channel_callbacks, + channel_request_response_function, + channel->session, + channel); } return SSH_PACKET_USED; @@ -1744,6 +1766,12 @@ SSH_PACKET_CALLBACK(ssh_packet_channel_failure){ channel->request_state); } else { channel->request_state=SSH_CHANNEL_REQ_STATE_DENIED; + + ssh_callbacks_execute_list(channel->callbacks, + ssh_channel_callbacks, + channel_request_response_function, + channel->session, + channel); } return SSH_PACKET_USED; 
diff --git a/tests/client/CMakeLists.txt b/tests/client/CMakeLists.txt
index 71e5182e2..4772efc96 100644
--- a/tests/client/CMakeLists.txt
+++ b/tests/client/CMakeLists.txt
@@ -4,6 +4,7 @@ find_package(socket_wrapper)
 
 set(LIBSSH_CLIENT_TESTS
     torture_algorithms
+    torture_client_callbacks
     torture_client_config
     torture_connect
     torture_hostkey
diff --git a/tests/client/torture_client_callbacks.c b/tests/client/torture_client_callbacks.c
new file mode 100644
index 000000000..3b9b85a35
--- /dev/null
+++ b/tests/client/torture_client_callbacks.c
@@ -0,0 +1,285 @@
+/*
+ * This file is part of the SSH Library
+ *
+ * Copyright (c) 2012 by Aris Adamantiadis
+ *
+ * The SSH Library is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation; either version 2.1 of the License, or (at your
+ * option) any later version.
+ *
+ * The SSH Library is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
+ * License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License
+ * along with the SSH Library; see the file COPYING. If not, write to
+ * the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
+ * MA 02111-1307, USA.
+ */
+
+#include "config.h"
+
+#define LIBSSH_STATIC
+
+#include "torture.h"
+#include "libssh/libssh.h"
+#include "libssh/priv.h"
+#include "libssh/session.h"
+#include "libssh/callbacks.h"
+
+#include <errno.h>
+#include <sys/types.h>
+#include <pwd.h>
+
+/* 0 (left by calloc) means the callback has not run yet. */
+#define STATE_SUCCESS (1)
+#define STATE_FAILURE (2)
+
+/* What the callbacks observed and what they expect to be handed. */
+struct callback_state
+{
+    int open_response;
+    int request_response;
+    ssh_session expected_session;
+    ssh_channel expected_channel;
+    /* Registered callback struct, kept here so the test can free it. */
+    struct ssh_channel_callbacks_struct *callbacks;
+};
+
+/* Records whether the channel open was confirmed or refused. */
+static void on_open_response(ssh_session session, ssh_channel channel, bool is_success, void *userdata)
+{
+    struct callback_state *state = (struct callback_state*)userdata;
+    assert_ptr_equal(state->expected_session, session);
+    assert_ptr_equal(state->expected_channel, channel);
+    state->open_response = is_success ? STATE_SUCCESS : STATE_FAILURE;
+}
+
+/*
+ * Records that a request response arrived. The callback carries no
+ * accepted/denied flag, so STATE_SUCCESS here only means "was invoked".
+ */
+static void on_request_response(ssh_session session, ssh_channel channel, void *userdata)
+{
+    struct callback_state *state = (struct callback_state*)userdata;
+    assert_ptr_equal(state->expected_session, session);
+    assert_ptr_equal(state->expected_channel, channel);
+    state->request_response = STATE_SUCCESS;
+}
+
+/*
+ * Allocates the tracking state and a callback struct, wires both response
+ * callbacks to the state and registers the struct on the channel.
+ * Release the result with channel_cleanup().
+ */
+static struct callback_state *set_callbacks(ssh_session session, ssh_channel channel)
+{
+    int rc;
+    struct ssh_channel_callbacks_struct *cb = NULL;
+    struct callback_state *cb_state = NULL;
+
+    cb_state = calloc(1, sizeof(*cb_state));
+    assert_non_null(cb_state);
+    cb_state->expected_session = session;
+    cb_state->expected_channel = channel;
+
+    cb = calloc(1, sizeof(*cb));
+    assert_non_null(cb);
+    ssh_callbacks_init(cb);
+    cb->userdata = cb_state;
+    cb->channel_open_response_function = on_open_response;
+    cb->channel_request_response_function = on_request_response;
+    cb_state->callbacks = cb;
+    rc = ssh_set_channel_callbacks(channel, cb);
+    assert_ssh_return_code(session, rc);
+
+    return cb_state;
+}
+
+/*
+ * Frees the channel together with everything set_callbacks() allocated.
+ * The callback struct is unregistered first so the library is not left
+ * holding a pointer to it once it has been freed.
+ */
+static void channel_cleanup(ssh_channel channel, struct callback_state *cb_state)
+{
+    int rc;
+
+    rc = ssh_remove_channel_callbacks(channel, cb_state->callbacks);
+    assert_int_equal(rc, SSH_OK);
+
+    ssh_channel_free(channel);
+
+    free(cb_state->callbacks);
+    free(cb_state);
+}
+
+static int sshd_setup(void **state)
+{
+    torture_setup_sshd_server(state, false);
+
+    return 0;
+}
+
+static int sshd_teardown(void **state)
+{
+    torture_teardown_sshd_server(state);
+
+    return 0;
+}
+
+static int session_setup(void **state)
+{
+    struct torture_state *s = *state;
+    struct passwd *pwd;
+    int rc;
+
+    pwd = getpwnam("bob");
+    assert_non_null(pwd);
+
+    rc = setuid(pwd->pw_uid);
+    assert_return_code(rc, errno);
+
+    s->ssh.session = torture_ssh_session(s,
+                                         TORTURE_SSH_SERVER,
+                                         NULL,
+                                         TORTURE_SSH_USER_ALICE,
+                                         NULL);
+    assert_non_null(s->ssh.session);
+
+    return 0;
+}
+
+static int session_teardown(void **state)
+{
+    struct torture_state *s = *state;
+
+    ssh_disconnect(s->ssh.session);
+    ssh_free(s->ssh.session);
+
+    return 0;
+}
+
+/* A session channel opens and the open callback reports success. */
+static void torture_open_success(void **state)
+{
+    struct torture_state *s = *state;
+    ssh_session session = s->ssh.session;
+    ssh_channel channel;
+    int rc;
+    struct callback_state *cb_state = NULL;
+
+    channel = ssh_channel_new(session);
+    assert_non_null(channel);
+
+    cb_state = set_callbacks(session, channel);
+
+    rc = ssh_channel_open_session(channel);
+    assert_ssh_return_code(session, rc);
+
+    assert_int_equal(STATE_SUCCESS, cb_state->open_response);
+
+    channel_cleanup(channel, cb_state);
+}
+
+/* A channel open that fails makes the open callback report failure. */
+static void torture_open_failure(void **state)
+{
+    struct torture_state *s = *state;
+    ssh_session session = s->ssh.session;
+    ssh_channel channel;
+    int rc;
+    struct callback_state *cb_state = NULL;
+
+    channel = ssh_channel_new(session);
+    assert_non_null(channel);
+
+    cb_state = set_callbacks(session, channel);
+
+    rc = ssh_channel_open_forward(channel, "0.0.0.0", 0, "0.0.0.0", 0);
+    assert_ssh_return_code_equal(session, rc, SSH_ERROR);
+
+    assert_int_equal(STATE_FAILURE, cb_state->open_response);
+
+    channel_cleanup(channel, cb_state);
+}
+
+/* A successful exec request invokes the request response callback. */
+static void torture_request_success(void **state)
+{
+    struct torture_state *s = *state;
+    ssh_session session = s->ssh.session;
+    ssh_channel channel;
+    int rc;
+    struct callback_state *cb_state = NULL;
+
+    channel = ssh_channel_new(session);
+    assert_non_null(channel);
+
+    cb_state = set_callbacks(session, channel);
+
+    rc = ssh_channel_open_session(channel);
+    assert_ssh_return_code(session, rc);
+
+    rc = ssh_channel_request_exec(channel, "echo -n ABCD");
+    assert_ssh_return_code(session, rc);
+
+    assert_int_equal(STATE_SUCCESS, cb_state->request_response);
+
+    channel_cleanup(channel, cb_state);
+}
+
+/* A failing env request still invokes the request response callback. */
+static void torture_request_failure(void **state)
+{
+    struct torture_state *s = *state;
+    ssh_session session = s->ssh.session;
+    ssh_channel channel;
+    int rc;
+    struct callback_state *cb_state = NULL;
+
+    channel = ssh_channel_new(session);
+    assert_non_null(channel);
+
+    cb_state = set_callbacks(session, channel);
+
+    rc = ssh_channel_open_session(channel);
+    assert_ssh_return_code(session, rc);
+
+    rc = ssh_channel_request_env(channel, "NOT_ACCEPTED", "VALUE");
+    assert_ssh_return_code_equal(session, rc, SSH_ERROR);
+
+    assert_int_equal(STATE_SUCCESS, cb_state->request_response);
+
+    channel_cleanup(channel, cb_state);
+}
+
+int torture_run_tests(void)
+{
+    int rc;
+    struct CMUnitTest tests[] = {
+        cmocka_unit_test_setup_teardown(torture_open_success,
+                                        session_setup,
+                                        session_teardown),
+        cmocka_unit_test_setup_teardown(torture_open_failure,
+                                        session_setup,
+                                        session_teardown),
+        cmocka_unit_test_setup_teardown(torture_request_success,
+                                        session_setup,
+                                        session_teardown),
+        cmocka_unit_test_setup_teardown(torture_request_failure,
+                                        session_setup,
+                                        session_teardown),
+    };
+
+    ssh_init();
+
+    torture_filter_tests(tests);
+    rc = cmocka_run_group_tests(tests, sshd_setup, sshd_teardown);
+
+    ssh_finalize();
+
+    return rc;
+}
-- 
GitLab