Is there a problem with the socket class I wrote on top of boost::asio to simplify working with the network?

Time:04-25

I wrote a socket class that wraps all the work with the asynchronous boost::asio methods. I did it to reduce code: just inherit from this class and use its methods! Are there any flaws? I'm not sure whether parts of the implementation contain UB or bugs.

#include <boost/asio.hpp>

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>

namespace network {
    enum Type {
        UDP,
        TCP
    };

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    struct SocketImpl : public std::enable_shared_from_this<SocketImpl<socket_type, resolver_type, endpoint_iter_type>> {
    public:
        typedef std::function<void()> ConnectCallback, PromoteCallback, PostCallback;
        typedef std::function<void(size_t)> WriteCallback;
        typedef std::function<void(const uint8_t *, size_t)> ReadCallback;
        typedef std::function<void(const std::string &)> ErrorCallback;

        explicit SocketImpl(const boost::asio::strand<typename boost::asio::io_service::executor_type> &executor)
            : socket_(executor), resolver_(executor), timeout_(executor) {}

        explicit SocketImpl(socket_type sock)
            : resolver_(sock.get_executor()), timeout_(sock.get_executor()), socket_(std::move(sock)) {}

        void Post(const PostCallback &callback);

        auto Get() { return this->shared_from_this(); }

        void Connect(std::string Host, std::string Port, const ConnectCallback &connect_callback, const ErrorCallback &error_callback);

        virtual void Send(const uint8_t *message_data, size_t size, const WriteCallback &write_callback, const ErrorCallback &error_callback) = 0;

        virtual void Read(size_t size, const ReadCallback &read_callback, const ErrorCallback &error_callback) = 0;

        template <typename Handler> void Await(boost::posix_time::time_duration ms, Handler f);

        virtual void Disconnect();

        ~SocketImpl();

    protected:
        void stop_await();
        
        virtual void do_resolve(std::string host, std::string port, const SocketImpl::ConnectCallback &connect_callback,
            const SocketImpl::ErrorCallback &error_callback) = 0;

        void deadline();

        resolver_type resolver_;

        endpoint_iter_type endpoint_iter_;

        socket_type socket_;

        boost::asio::deadline_timer timeout_;

        boost::asio::streambuf buff_;
    };

    template <Type t>
    struct Socket
        : public SocketImpl<boost::asio::ip::tcp::socket, boost::asio::ip::tcp::resolver, boost::asio::ip::tcp::resolver::iterator> {
        explicit Socket(const boost::asio::strand<typename boost::asio::io_service::executor_type> &executor) : SocketImpl(executor) {}

        explicit Socket(boost::asio::ip::tcp::socket sock) : SocketImpl(std::move(sock)) {
            if (socket_.is_open())
                is_connected = true;
        }

        void Send(const uint8_t *message_data, size_t size, const WriteCallback &write_callback, const ErrorCallback &error_callback) override {
            auto self = Get();
            Post([this, self, message_data, size, write_callback, error_callback] {
                boost::asio::async_write(socket_, boost::asio::buffer(message_data, size),
                    [this, self, write_callback, error_callback](boost::system::error_code ec, std::size_t bytes_transferred) {
                        if (!ec) {
                            write_callback(bytes_transferred);
                        } else {
#ifdef OS_WIN
                            SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                            error_callback(ec.message());
                        }
                    });
            });
        }

        void Read(size_t size, const ReadCallback &read_callback, const ErrorCallback &error_callback) override {
            auto self = Get();
            Post([this, self, size, read_callback, error_callback] {
                boost::asio::async_read(socket_, boost::asio::buffer(buff_.prepare(size)),
                    [this, self, read_callback, error_callback](boost::system::error_code ec, std::size_t length) {
                        stop_await();
                        if (!ec) {
                            const uint8_t *data = boost::asio::buffer_cast<const uint8_t *>(buff_.data());
                            read_callback(data, length);
                        } else {
#ifdef OS_WIN
                            SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                            error_callback(ec.message());
                        }
                        buff_.consume(length);
                    });
            });
        }

        bool IsConnected() const { return is_connected; }

        void ReadUntil(std::string until_str, const ReadCallback &read_callback, const ErrorCallback &error_callback) {
            auto self = Get();
            Post([this, self, until_str = std::move(until_str), read_callback, error_callback] {
                boost::asio::async_read_until(socket_, buff_, until_str,
                    [this, read_callback, error_callback](boost::system::error_code ec, std::size_t bytes_transferred) {
                        stop_await();
                        if (!ec) {
                            const uint8_t *data = boost::asio::buffer_cast<const uint8_t *>(buff_.data());
                            read_callback(data, bytes_transferred);
                        } else {
#ifdef OS_WIN
                            SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                            error_callback(ec.message());
                        }
                        buff_.consume(bytes_transferred);
                    });
            });
        }

    protected:
        void do_resolve(std::string host, std::string port, const SocketImpl::ConnectCallback &connect_callback,
            const SocketImpl::ErrorCallback &error_callback) override {
            auto self = Get();
            resolver_.async_resolve(host, port,
                [this, self, connect_callback, error_callback](
                    boost::system::error_code ec, boost::asio::ip::tcp::resolver::iterator endpoints) {
                    stop_await();
                    if (!ec) {
                        endpoint_iter_ = std::move(endpoints);
                        do_connect(endpoint_iter_, connect_callback, error_callback);
                    } else {
#ifdef OS_WIN
                        SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                        error_callback("Unable to resolve host: " + ec.message());
                    }
                });
        }

        void do_connect(boost::asio::ip::tcp::resolver::iterator endpoints, const SocketImpl::ConnectCallback &connect_callback,
            const SocketImpl::ErrorCallback &error_callback) {
            auto self = Get();
            boost::asio::async_connect(socket_, std::move(endpoints),
                [this, self, connect_callback, error_callback](
                    boost::system::error_code ec, [[maybe_unused]] const boost::asio::ip::tcp::resolver::iterator &) {
                    stop_await();
                    if (!ec) {
                        connect_callback();
                    } else {
#ifdef OS_WIN
                        SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                        error_callback("Unable to connect host: " + ec.message());
                    }
                });
        }

        bool is_connected = false;
    };

    template <>
    struct Socket<UDP>
        : public SocketImpl<boost::asio::ip::udp::socket, boost::asio::ip::udp::resolver, boost::asio::ip::udp::resolver::iterator> {
    public:
        explicit Socket(const boost::asio::strand<typename boost::asio::io_service::executor_type> &executor) : SocketImpl(executor) {}

        explicit Socket(boost::asio::ip::udp::socket sock) : SocketImpl(std::move(sock)) {}

        void Send(const uint8_t *message_data, size_t size, const WriteCallback &write_callback, const ErrorCallback &error_callback) override {
            auto self = Get();
            Post([this, self, message_data, size, write_callback, error_callback] {
                socket_.async_send_to(boost::asio::buffer(message_data, size), *endpoint_iter_,
                    [this, self, write_callback, error_callback](boost::system::error_code ec, size_t bytes_transferred) {
                        if (!ec) {
                            write_callback(bytes_transferred);
                        } else {
#ifdef OS_WIN
                            SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                            error_callback(ec.message());
                        }
                    });
            });
        }

        void Read(size_t size, const ReadCallback &read_callback, const ErrorCallback &error_callback) override {
            auto self = Get();
            Post([this, self, size, read_callback, error_callback] {
                boost::asio::ip::udp::endpoint endpoint = *endpoint_iter_;
                socket_.async_receive_from(boost::asio::buffer(buff_.prepare(size)), endpoint,
                    [this, self, read_callback, error_callback](boost::system::error_code ec, size_t bytes_transferred) {
                        stop_await();
                        if (!ec) {
                            const auto *data = boost::asio::buffer_cast<const uint8_t *>(buff_.data());
                            read_callback(data, bytes_transferred);
                        } else {
                            error_callback(ec.message());
                        }
                        buff_.consume(bytes_transferred);
                    });
            });
        }

        void Promote(const PromoteCallback &callback);

    protected:
        void do_resolve(std::string host, std::string port, const SocketImpl::ConnectCallback &connect_callback,
            const SocketImpl::ErrorCallback &error_callback) override {
            auto self = Get();
            resolver_.async_resolve(host, port,
                [this, self, connect_callback, error_callback](
                    boost::system::error_code ec, boost::asio::ip::udp::resolver::iterator endpoints) {
                    stop_await();
                    if (!ec) {
                        endpoint_iter_ = std::move(endpoints);
                        boost::asio::ip::udp::endpoint endpoint = *endpoint_iter_;
                        socket_.open(endpoint.protocol());

                        connect_callback();
                    } else {
#ifdef OS_WIN
                        SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                        error_callback("Unable to resolve host: " + ec.message());
                    }
                });
        }
    };

    void Socket<UDP>::Promote(const PromoteCallback &callback) {
        auto self = Get();
        Post([this, self, callback] {
            endpoint_iter_++;
            socket_.cancel();
            callback();
        });
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::Post(const SocketImpl::PostCallback &callback) {
        post(socket_.get_executor(), callback);
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::Connect(std::string Host, std::string Port,
        const SocketImpl::ConnectCallback &connect_callback, const SocketImpl::ErrorCallback &error_callback) {
        auto self = Get();
        Post([this, self, Host, Port, connect_callback, error_callback] { do_resolve(Host, Port, connect_callback, error_callback); });
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    template <typename Handler>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::Await(boost::posix_time::time_duration ms, Handler f) {
        auto self = Get();
        Post([this, ms, self, f] {
            timeout_.expires_from_now(ms);
            timeout_.template async_wait([this, self, f](boost::system::error_code const &ec) {
                if (!ec) {
                    deadline(f);
                }
            });
        });
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::Disconnect() {
        auto self = Get();
        Post([this, self] {
#ifdef OS_WIN
            SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
            timeout_.cancel();
            resolver_.cancel();
            if (socket_.is_open()) socket_.cancel();
        });
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::stop_await() {
        timeout_.cancel();
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::deadline() {
        if (timeout_.expires_at() <= boost::asio::deadline_timer::traits_type::now()) {
            timeout_.cancel();
            socket_.cancel();
        } else {
            auto self(Get());
            timeout_.async_wait([this, self](boost::system::error_code ec) {
                if (!ec) {
                    deadline();
                }
            });
        }
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    SocketImpl<socket_type, resolver_type, endpoint_iter_type>::~SocketImpl() {
        if (socket_.is_open()) socket_.close();
    }
} // namespace network

I use it like this (C++17):

struct Client : Socket<TCP> { ... };

Happy to take advice on this structure! Thanks!

Answer:

That's a lot of code.

  1. Always compile with warnings enabled. This would have told you that members are not initialized in the order you list their initializers, but in the order they are declared. Importantly, in the second constructor this is UB:

    explicit SocketImpl(socket_type sock)
        : resolver_(sock.get_executor()), timeout_(sock.get_executor()), socket_(std::move(sock)) {}
    

    Because socket_ is declared before timeout_, it is also initialized before it, meaning that timeout_(sock.get_executor()) is actually a use-after-move. Oops. Fix it:

    explicit SocketImpl(socket_type sock)
        : resolver_(sock.get_executor()), socket_(std::move(sock)), timeout_(socket_.get_executor()) {}
    

    Now, even though the other constructor doesn't have such a problem, it's good practice to match declaration order there as well:

    explicit SocketImpl(Executor executor)
        : resolver_(executor)
        , socket_(executor)
        , timeout_(executor) {}
    
    explicit SocketImpl(socket_type sock)
        : resolver_(sock.get_executor())
        , socket_(std::move(sock))
        , timeout_(socket_.get_executor()) {}
    

    (Kudos for making constructors explicit)

  2. I'd implement any Impl class inline (the naming suggests that the entire class is "implementation detail" anyways).

  3. Destructors like this are busy-work:

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    SocketImpl<socket_type, resolver_type, endpoint_iter_type>::~SocketImpl() {
        if (socket_.is_open()) {
            socket_.close();
        }
    }
    

    The default destructor of socket_ already does that. All you achieve is getting in the way of the compiler generating optimal, exception-safe code. E.g. in this case close() might throw an exception, and since destructors are implicitly noexcept, that would call std::terminate. Did you want that?

  4. Consider taking arguments that hold resources by const-reference, or by value if you intend to std::move() from them.

    virtual void do_resolve(std::string host, std::string port,
                            ConnectCallback const&,
                            ErrorCallback const&) = 0;
    
  5. These instantiations:

    template <Type>
    struct Socket
        : public SocketImpl<boost::asio::ip::tcp::socket,
                            boost::asio::ip::tcp::resolver,
                            boost::asio::ip::tcp::resolver::iterator> {
    

    and

    template <>
    struct Socket<UDP>
        : public SocketImpl<boost::asio::ip::udp::socket,
                            boost::asio::ip::udp::resolver,
                            boost::asio::ip::udp::resolver::iterator> {
    

    Seem laborious. Why not use the generic templates and protocols from Asio directly? You could even throw in a free performance optimization by allowing callers to override the type-erased executor type:

    template <typename Protocol,
              typename Executor = boost::asio::any_io_executor>
    struct SocketImpl
        : public std::enable_shared_from_this<SocketImpl<Protocol, Executor>> {
      public:
        using base_type     = SocketImpl<Protocol, Executor>;
        using socket_type   = std::conditional_t<
            std::is_same_v<Protocol, boost::asio::ip::udp>,
            boost::asio::basic_datagram_socket<Protocol, Executor>,
            boost::asio::basic_stream_socket<Protocol, Executor>>;
        using resolver_type =
            boost::asio::ip::basic_resolver<Protocol, Executor>;
        using endpoint_iter_type = typename resolver_type::iterator;
    

    Now your instantiations can just be:

    template <Type> struct Socket : public SocketImpl<boost::asio::ip::tcp> {
        // ...
    template <> struct Socket<UDP> : public SocketImpl<boost::asio::ip::udp> {
    

    with the exact behaviour you had, or better:

    using StrandEx = boost::asio::strand<boost::asio::io_context::executor_type>;
    
    template <Type> struct Socket : public SocketImpl<boost::asio::ip::tcp, StrandEx> {
        // ...
    template <> struct Socket<UDP> : public SocketImpl<boost::asio::ip::udp, StrandEx> {
    

    with the executor optimized for the strand as you were restricting it to anyways!

  6. Instead of repeating the type arguments:

    explicit Socket(boost::asio::ip::tcp::socket sock) : SocketImpl(std::move(sock)) {
    

    refer to exposed typedefs, so you have a single source of truth:

    explicit Socket(base_type::socket_type sock) : SocketImpl(std::move(sock)) {
    
  7. Pass executors by value. They're cheap to copy, and you could even move from them since you're "sinking" them into your members.

  8. In fact, just inherit the constructors wholesale instead of repeating them. So even:

    template <Type>
    struct Socket : public SocketImpl<boost::asio::ip::tcp, StrandEx> {
        explicit Socket(StrandEx executor) : SocketImpl(executor) {}
        explicit Socket(base_type::socket_type sock)
            : SocketImpl(std::move(sock)) {}
    

    Could just be:

    template <Type>
    struct Socket : public SocketImpl<boost::asio::ip::tcp, StrandEx> {
        using base_type::base_type;
    

    and land you with the exact same set of constructors.

  9. That constructor sets is_connected, but it's lying: it sets it to true when the socket is merely open. You don't want this, nor do you need it.

    In your code, nobody is using it. What you might want in a deriving client is a state machine. It's up to them. No need to add a racy, lying interface to your base class. Leave the responsibility where it belongs.

  10. Same with this:

     #ifdef OS_WIN
         SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
     #endif
    

    That's a violation of separation of concerns. You might want this behaviour, but your callers/users might want something else. Worse, this behaviour may break their code that had a different preference in place.

  11. Get() does nothing but obscure that it returns shared_from_this(). If it is there to avoid explicitly qualifying with this-> (because the base class is a dependent type), just, again, use a using-declaration:

    using std::enable_shared_from_this<SocketImpl>::shared_from_this;
    
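  12. There's a big problem with PostCallback being std::function: it hides associated executor types! Whatever executor the caller associated with a handler (e.g. with boost::asio::bind_executor) is no longer visible to Asio once the handler is wrapped in a std::function.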
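The initialization-order trap from point 1 can be reproduced without Boost. This is a minimal sketch using a hypothetical `Named` type in place of the Asio members: members are always initialized in declaration order, whatever order the mem-initializer list is written in (GCC's `-Wall` enables `-Wreorder`, which flags the mismatch). With `std::string` the moved-from argument is only "valid but unspecified", so the buggy version silently copies an unpredictable value instead of crashing.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in for an Asio member: it just records the string it was built from.
struct Named {
    std::string value;
    explicit Named(std::string v) : value(std::move(v)) {}
};

// Buggy: socket_ is declared first, so it is initialized first and moves from `sock`;
// timeout_ is initialized second and copies the already moved-from `sock`.
struct Buggy {
    Named socket_;
    Named timeout_;
    explicit Buggy(std::string sock)
        : timeout_(sock)            // written first, but runs second (-Wreorder)
        , socket_(std::move(sock)) {}
};

// Fixed: initializer list matches declaration order, and the later member
// reads from the already-constructed member rather than the moved-from argument.
struct Fixed {
    Named socket_;
    Named timeout_;
    explicit Fixed(std::string sock)
        : socket_(std::move(sock))
        , timeout_(socket_.value) {}
};
```

In `Buggy`, only `socket_` is guaranteed to hold the original value; `timeout_` holds whatever the moved-from string happens to contain.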

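Points 8 and 11 can also be tried in isolation. The sketch below assumes a hypothetical `ImplBase` template standing in for `SocketImpl`, with a string and an int standing in for the executor and the adopted socket. It shows that a using-declaration makes `shared_from_this()` callable unqualified despite the dependent `enable_shared_from_this` base (so no `Get()` wrapper is needed), and that `using base_type::base_type;` inherits every base constructor, `explicit` included.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

struct Tcp {};  // hypothetical protocol tag

template <typename Protocol>
struct ImplBase : std::enable_shared_from_this<ImplBase<Protocol>> {
    using base_type = ImplBase<Protocol>;
    // Bring the dependent base's member into scope instead of wrapping it in Get().
    using std::enable_shared_from_this<ImplBase>::shared_from_this;

    explicit ImplBase(std::string executor) : executor_(std::move(executor)) {}
    explicit ImplBase(int handle) : executor_("adopted"), handle_(handle) {}
    virtual ~ImplBase() = default;

    // Unqualified call compiles thanks to the using-declaration above.
    std::shared_ptr<ImplBase> self() { return shared_from_this(); }

    std::string executor_;
    int handle_ = -1;
};

// Inherits both constructors of ImplBase<Tcp> instead of repeating them.
struct Socket : ImplBase<Tcp> {
    using base_type::base_type;
};
```

Note that the object must already be owned by a `std::shared_ptr` (here via `std::make_shared`) before `self()` is called, exactly as with the original `Get()`.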
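To make point 12 concrete without Boost, here is a toy model. Asio's `associated_executor` (unless specialized) uses a handler's nested `executor_type` and its `get_executor()` if present, and otherwise falls back to a default executor. `has_executor_type`, `associated_executor_name`, `BoundHandler` and the two executor structs are hypothetical stand-ins; the real trait also honours specializations, which this sketch ignores.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <type_traits>

// Hypothetical executors: the default one and a strand the caller wants to use.
struct DefaultExecutor { std::string name = "io_context"; };
struct StrandExecutor  { std::string name = "strand"; };

// Detects a nested executor_type, the hook asio's associated_executor looks for.
template <typename T, typename = void>
struct has_executor_type : std::false_type {};

template <typename T>
struct has_executor_type<T, std::void_t<typename T::executor_type>> : std::true_type {};

// Simplified model: use the handler's own executor if it has one, else the fallback.
template <typename T>
std::string associated_executor_name(const T& handler, const DefaultExecutor& fallback) {
    if constexpr (has_executor_type<T>::value)
        return handler.get_executor().name;
    else
        return fallback.name;
}

// A handler carrying its own executor, similar in spirit to bind_executor(strand, f).
struct BoundHandler {
    using executor_type = StrandExecutor;
    executor_type get_executor() const { return {}; }
    void operator()() const {}
};
```

Wrapping `BoundHandler` in `std::function<void()>` still lets you call it, but the nested `executor_type` is gone, so the model falls back to the default executor. One plausible remedy (an assumption, not spelled out in the answer) is to make `Post` a template on the handler type and forward the handler to `post()` unchanged.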