浏览代码

Merge branch 'kbf-29' of kraxor/kbf into develop

Bence Balint 3 年之前
父节点
当前提交
ce917b09b9
共有 7 个文件被更改,包括 227 次插入42 次删除
  1. 0 2
      include/kbf/http/common.h
  2. 78 16
      include/kbf/web_service.h
  3. 10 8
      src/http/common.cpp
  4. 61 9
      src/web_service.cpp
  5. 3 0
      test/test_http.cpp
  6. 74 6
      test/test_web_service.cpp
  7. 1 1
      test/test_wifi.cpp

+ 0 - 2
include/kbf/http/common.h

@@ -36,8 +36,6 @@ namespace kbf::http {
         /**
          * @brief HTTP headers
          *
-         * @warning Parsing headers from the incoming response is not supported in kbf::http::Client (yet).
-         *
          * @note Only the headers listed in acceptedHeaders will be parsed automatically for incoming
          * requests. For parsing custom headers, use readHeader().
          */

+ 78 - 16
include/kbf/web_service.h

@@ -50,8 +50,8 @@ namespace kbf {
          * @tparam T Controller class
          */
         template<class T>
-        void attach() {
-            static_assert(std::is_base_of<Controller, T>::value, "attach(): type must be a Controller");
+        void controller() {
+            static_assert(std::is_base_of<Controller, T>::value, "controller(): type must be a Controller");
 
             Controller *controller = new T();
 
@@ -62,14 +62,38 @@ namespace kbf {
                 handler = Controller::responseSentHandler;
             }
 
-            server.route({kbf::http::Method::GET, controller->path, handleGet, controller, handler});
-            server.route({kbf::http::Method::POST, controller->path, handlePost, controller, handler});
+            auto *arg = new HandlerArg{*this, *controller};
+            server.route({http::Method::GET, controller->path, handleGet, arg, handler});
+            server.route({http::Method::POST, controller->path, handlePost, arg, handler});
         }
 
+        class Middleware;
+
+        /**
+         * @brief Adds a #Middleware to the service.
+         *
+         * @tparam T Middleware class
+         */
+        template<class T>
+        void middleware() {
+            static_assert(std::is_base_of<Middleware, T>::value, "middleware(): type must be a Middleware");
+            middlewares.push_back(new T());
+        }
+
+        /**
+         * @brief Continue execution of the middleware pipeline.
+         *
+         * @note Should be called from middlewares if you want the pipeline to continue.
+         *
+         * @param request
+         * @return response
+         */
+        http::Response next(const http::Request &request);
+
         /**
          * @brief Controller for the web service.
          *
-         * Pass derived classes to WebService::attach().
+         * Pass derived classes to WebService::controller().
          */
         class Controller {
             friend class WebService;
@@ -93,8 +117,8 @@ namespace kbf {
              * @param request
              * @return HTTP response
              */
-            virtual kbf::http::Response get(const kbf::http::Request &request) {
-                return kbf::http::Response("method not allowed", 405);
+            virtual http::Response get(const http::Request &request) {
+                return http::Response("method not allowed", 405);
             }
 
             /**
@@ -103,8 +127,8 @@ namespace kbf {
              * @param request
              * @return HTTP response
              */
-            virtual kbf::http::Response post(const kbf::http::Request &request) {
-                return kbf::http::Response("method not allowed", 405);
+            virtual http::Response post(const http::Request &request) {
+                return http::Response("method not allowed", 405);
             }
 
             /**
@@ -112,21 +136,59 @@ namespace kbf {
              *
              * @param response
              */
-            virtual void onResponseSent(const kbf::http::Response &response) {}
+            virtual void onResponseSent(const http::Response &response) {}
 
         private:
-            static void responseSentHandler(const kbf::http::Response &response, void *data);
+            static void responseSentHandler(const http::Response &response, void *data);
         };
 
-    private:
-        kbf::http::Server server;
-
-        static kbf::http::Response handleGet(const kbf::http::Request &request, void *data);
+        /**
+         * @brief Base class for HTTP middlewares.
+         *
+         * Pass derived classes to #WebService::middleware().
+         *
+         * The #run() method will be called for every request. Implementations should call webService->next() for
+         * the pipeline to continue.
+         *
+         * @note If multiple middlewares are added to a WebService, they will be called in the order they were added.
+         */
+        class Middleware {
+        public:
+            /**
+             * @brief Function to run for every request. Should call webService->next() for the pipeline to continue.
+             *
+             * @param request incoming request
+             * @param webService webService instance
+             * @return response
+             */
+            virtual http::Response run(const http::Request &request, WebService &webService) = 0;
+        };
 
-        static kbf::http::Response handlePost(const kbf::http::Request &request, void *data);
+    private:
+        http::Server server;
 
         bool running = false;
         int  port;
+
+        struct HandlerArg {
+            WebService &webService;
+            Controller &controller;
+        };
+
+        typedef http::Response (WebService::Controller::*MethodFunction)(const http::Request &request);
+
+        MethodFunction currentMethod           = nullptr;
+        Controller     *currentController      = nullptr;
+
+        static http::Response handleGet(const http::Request &request, void *data);
+
+        static http::Response handlePost(const http::Request &request, void *data);
+
+        http::Response startPipeline(const http::Request &request, HandlerArg *arg);
+
+        std::vector<Middleware *> middlewares;
+        std::optional<std::vector<Middleware *>::iterator>
+                                  middlewareIt = std::nullopt;
     };
 }
 

+ 10 - 8
src/http/common.cpp

@@ -9,6 +9,8 @@
 #include "kbf/macros.h"
 
 using namespace kbf;
+using http::Request;
+using http::Response;
 using std::string;
 using std::map;
 using std::vector;
@@ -27,7 +29,7 @@ void http::HTTPObject::parseToMap(map<string, string> &target, const string &buf
     }
 }
 
-http::Request::Request(httpd_req_t *httpdRequest) {
+Request::Request(httpd_req_t *httpdRequest) {
     this->httpdRequest = httpdRequest;
     method = static_cast<Method>(httpdRequest->method);
     uri    = httpdRequest->uri;
@@ -39,7 +41,7 @@ http::Request::Request(httpd_req_t *httpdRequest) {
     readBody();
 }
 
-string http::Request::readHeader(const string &header) const {
+string Request::readHeader(const string &header) const {
     auto len = httpd_req_get_hdr_value_len(httpdRequest, header.c_str());
     if (len == 0) {
         ESP_LOGD(TAG, "  header not found: %s", header.c_str());
@@ -51,11 +53,11 @@ string http::Request::readHeader(const string &header) const {
     return buffer;
 }
 
-void http::Request::readAndStoreHeader(const string &header) {
+void Request::readAndStoreHeader(const string &header) {
     headers[header] = readHeader(header);
 }
 
-void http::Request::readQuery() {
+void Request::readQuery() {
     auto queryLen = httpd_req_get_url_query_len(httpdRequest);
     char buffer[queryLen + 1];
 
@@ -70,7 +72,7 @@ void http::Request::readQuery() {
     parseToMap(query, buffer);
 }
 
-void http::Request::readBody() {
+void Request::readBody() {
     if (httpdRequest->content_len == 0) {
         return;
     }
@@ -101,19 +103,19 @@ void http::Request::readBody() {
     }
 }
 
-http::Response::Response(string body, int status, string contentType) :
+Response::Response(string body, int status, string contentType) :
         HTTPObject(std::move(body)),
         status(status),
         contentType(std::move(contentType)) {
 }
 
-http::Response::Response(const char *body, int status, string contentType) :
+Response::Response(const char *body, int status, string contentType) :
         HTTPObject(body),
         status(status),
         contentType(std::move(contentType)) {
 }
 
-http::Response::Response(const nlohmann::json &json, int status, string contentType) :
+Response::Response(const nlohmann::json &json, int status, string contentType) :
         HTTPObject(json.dump()),
         status(status),
         contentType(std::move(contentType)) {

+ 61 - 9
src/web_service.cpp

@@ -5,17 +5,23 @@
 #include <utility>
 
 using namespace kbf;
+using http::Request;
+using http::Response;
 
 WebService::WebService(int port) : server(), port(port) {
+    ESP_LOGI(TAG, "%s(%d)", __func__, this->port);
 }
 
 WebService::~WebService() {
+    ESP_LOGI(TAG, "%s()", __func__);
     if (running) {
         stop();
     }
 }
 
 void WebService::start() {
+    ESP_LOGI(TAG, "%s()", __func__);
+
     if (running) {
         ABORT("WebService already running");
     }
@@ -25,6 +31,8 @@ void WebService::start() {
 }
 
 void WebService::stop() {
+    ESP_LOGI(TAG, "%s()", __func__);
+
     if (!running) {
         ESP_LOGW(TAG, "stop(): not running");
         return;
@@ -34,19 +42,63 @@ void WebService::stop() {
     running = false;
 }
 
-kbf::http::Response WebService::handleGet(const http::Request &request, void *data) {
-    auto controller = static_cast<Controller *>(data);
-    return controller->get(request);
+Response WebService::startPipeline(const Request &request, HandlerArg *arg) {
+    ESP_LOGD(TAG, "%s()", __func__);
+
+    currentController = &arg->controller;
+    middlewareIt      = std::nullopt;
+
+    ESP_LOGI(TAG, "%s %s",
+             request.method == http::GET ? "GET" : request.method == http::POST ? "POST" : "UNKNOWN",
+             request.uri.c_str());
+
+    return next(request);
 }
 
-kbf::http::Response WebService::handlePost(const http::Request &request, void *data) {
-    auto controller = static_cast<Controller *>(data);
-    return controller->post(request);
+Response WebService::handleGet(const Request &request, void *data) {
+    ESP_LOGD(TAG, "%s()", __func__);
+
+    auto arg = static_cast<HandlerArg *>(data);
+    arg->webService.currentMethod = &Controller::get;
+
+    return arg->webService.startPipeline(request, arg);
+}
+
+Response WebService::handlePost(const Request &request, void *data) {
+    ESP_LOGD(TAG, "%s()", __func__);
+
+    auto arg = static_cast<HandlerArg *>(data);
+    arg->webService.currentMethod = &Controller::post;
+
+    return arg->webService.startPipeline(request, arg);
 }
 
 WebService::Controller::Controller(std::string path) : path(std::move(path)) {}
 
-void WebService::Controller::responseSentHandler(const http::Response &response, void *data) {
-    auto controller = static_cast<Controller *>(data);
-    controller->onResponseSent(response);
+void WebService::Controller::responseSentHandler(const Response &response, void *data) {
+    ESP_LOGD(TAG, "%s()", __func__);
+
+    auto arg = static_cast<HandlerArg *>(data);
+    arg->controller.onResponseSent(response);
+}
+
+Response WebService::next(const Request &request) { // noqa
+    ESP_LOGD(TAG, "%s()", __func__);
+
+    if (middlewareIt == std::nullopt) {
+        middlewareIt = middlewares.begin();
+    } else {
+        (*middlewareIt)++;
+    }
+
+    if (*middlewareIt == middlewares.end()) {
+        ESP_LOGD(TAG, "  running controller");
+        auto result = (currentController->*currentMethod)(request);
+        currentMethod     = nullptr;
+        currentController = nullptr;
+        return result;
+    }
+
+    ESP_LOGD(TAG, "  running middleware %d", *middlewareIt - middlewares.begin());
+    return (**middlewareIt)->run(request, *this);
 }

+ 3 - 0
test/test_http.cpp

@@ -27,17 +27,20 @@ TEST_CASE("HTTP GET, POST, 404, 405", "[kbf_http]") {
 
     http::Response (*handleGet)(const http::Request &, void *) = {[](const http::Request &request, void *) {
         TEST_ASSERT_EQUAL(http::GET, request.method);
+        TEST_ASSERT_EQUAL_STRING("/get-only", request.uri.c_str());
         return http::Response("OK");
     }};
     server.route({http::GET, "/get-only", handleGet, nullptr});
 
     http::Response (*handlePost)(const http::Request &, void *) = {[](const http::Request &request, void *) {
         TEST_ASSERT_EQUAL(http::POST, request.method);
+        TEST_ASSERT_EQUAL_STRING("/post-only", request.uri.c_str());
         return http::Response("OK");
     }};
     server.route({http::POST, "/post-only", handlePost, nullptr});
 
     http::Response (*handleGetAndPost)(const http::Request &, void *) = {[](const http::Request &request, void *) {
+        TEST_ASSERT_EQUAL_STRING("/get-and-post", request.uri.c_str());
         if (request.method == http::GET) {
             return http::Response("GET");
         } else if (request.method == http::POST) {

+ 74 - 6
test/test_web_service.cpp

@@ -37,16 +37,14 @@ public:
 };
 
 TEST_CASE("WebService", "[kbf_web_service]") {
-    using namespace kbf;
-
     wifi::start();
 
     auto webService = WebService();
-    webService.attach<CounterController>();
-    webService.attach<EchoController>();
+    webService.controller<CounterController>();
+    webService.controller<EchoController>();
     webService.start();
 
-    auto client   = http::Client();
+    auto client = http::Client();
 
     auto response = client.get("http://localhost/counter");
     TEST_ASSERT_EQUAL(200, response->status);
@@ -69,4 +67,74 @@ TEST_CASE("WebService", "[kbf_web_service]") {
 
     webService.stop();
     wifi::stop();
-}
+}
+
+class HeaderToParamMiddleware : public WebService::Middleware {
+public:
+    http::Response run(const http::Request &request, WebService &webService) override {
+        string str = request.readHeader("X-Request");
+
+        auto changedRequest = request;
+        if (request.method == http::GET) {
+            changedRequest.query["str"] = str;
+        } else if (request.method == http::POST) {
+            changedRequest.body = R"({"str":")" + str + "\"}";
+        } else {
+            TEST_FAIL();
+        }
+
+        auto response = webService.next(changedRequest);
+        response.headers["X-Response"] = response.body;
+        response.body = "changed";
+
+        return response;
+    }
+};
+
+class ReverseMiddleware : public WebService::Middleware {
+public:
+    http::Response run(const http::Request &request, WebService &webService) override {
+        string str;
+        if (request.method == http::GET) {
+            str = request.query.at("str");
+        } else if (request.method == http::POST) {
+            str = request.json().find("str")->get<string>();
+        } else {
+            TEST_FAIL();
+        }
+
+        std::reverse(str.begin(), str.end());
+
+        auto changedRequest = request;
+        if (request.method == http::GET) {
+            changedRequest.query["str"] = str;
+        } else {
+            changedRequest.body = R"({"str":")" + str + "\"}";
+        }
+
+        return webService.next(changedRequest);
+    }
+};
+
+TEST_CASE("WebService Middleware", "[kbf_web_service]") {
+    wifi::start();
+
+    auto webService = WebService();
+    webService.controller<EchoController>();
+    webService.middleware<HeaderToParamMiddleware>();
+    webService.middleware<ReverseMiddleware>();
+    webService.start();
+
+    auto client = http::Client();
+
+    auto response = client.get("http://localhost/echo", {{{"X-Request", "hello"}}});
+    TEST_ASSERT_EQUAL_STRING("changed", response->body.c_str());
+    TEST_ASSERT_EQUAL_STRING("olleh", response->headers.at("X-Response").c_str());
+
+    response = client.post("http://localhost/echo", nullptr,{{{"X-Request", "foo"}}});
+    TEST_ASSERT_EQUAL_STRING("changed", response->body.c_str());
+    TEST_ASSERT_EQUAL_STRING("oof", response->headers.at("X-Response").c_str());
+
+    webService.stop();
+    wifi::stop();
+}

+ 1 - 1
test/test_wifi.cpp

@@ -372,7 +372,7 @@ void wifiModeSwitchSlave() {
     wifi::start(ap, sta);
 
     auto webService = WebService();
-    webService.attach<TestController>();
+    webService.controller<TestController>();
     webService.start();
 
     sta->connect(KBF_TEST_WIFI_DUAL_MASTER_SSID, KBF_TEST_WIFI_PASS);