Workflow Learning Notes (2): The Programming Paradigm

Workflow is Sogou's C++ server engine and its programming paradigm. It powers almost all of Sogou's backend C++ online services, including every search service, the cloud input method, and online advertising, handling tens of billions of requests per day. It is a light and elegant enterprise-grade program engine that can satisfy most backend and embedded development needs.

1. Programming Paradigm

Program = Protocol + Algorithm + Task Flow

  • Protocol
    • In most cases, users rely on the built-in, general-purpose network protocols such as http, redis, or various rpc protocols.
    • Users can easily define their own network protocols: by supplying serialization and deserialization functions, they can build their own client/server.
  • Algorithm
    • In this design, the algorithm is the concept symmetric to the protocol: if invoking a protocol is an rpc, invoking an algorithm is an apc (Async Procedure Call).
    • Some general-purpose algorithms are provided, such as sort, merge, psort, and reduce, which can be used directly.
    • Compared with custom protocols, custom algorithms are far more common. Any complex computation with clear boundaries should be wrapped as an algorithm.
  • Task flow
    • The task flow is the actual business logic: it puts the developed protocols and algorithms to work inside a flow graph.
    • A typical task flow is a closed series-parallel graph. Complex business logic may be a non-closed DAG.
    • The task-flow graph can be built up front or generated dynamically from the result of each step. All tasks are executed asynchronously (see the sketch after this list).
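To make the three concepts concrete, here is a minimal sketch (not from the original article; the URL and the queue name "compute_queue" are arbitrary placeholders) that chains one protocol task (an HTTP request) and one algorithm-style computing task (a go task) on the same series:

#include <cstdio>
#include <string>
#include <workflow/WFTaskFactory.h>
#include <workflow/HttpUtil.h>
#include <workflow/WFFacilities.h>
#include <workflow/Workflow.h>

int main() {
    WFFacilities::WaitGroup wait_group(1);

    // Protocol: an HTTP client task (placeholder URL, 2 redirects, 1 retry).
    WFHttpTask* http_task = WFTaskFactory::create_http_task(
        "http://www.sogou.com", 2, 1, [](WFHttpTask* task) {
            if (task->get_state() == WFT_STATE_SUCCESS) {
                std::string body = protocol::HttpUtil::decode_chunked_body(task->get_resp());
                printf("fetched %zu bytes\n", body.size());
            }
            // Algorithm: a go (computing) task appended to the same series,
            // so it runs after this callback returns.
            auto* go_task = WFTaskFactory::create_go_task("compute_queue", []() {
                printf("doing CPU-bound work asynchronously\n");
            });
            series_of(task)->push_back(go_task);
        });

    // Task flow: the series is the flow graph; its callback fires when everything is done.
    Workflow::start_series_work(http_task, [&wait_group](const SeriesWork*) {
        wait_group.done();
    });
    wait_group.wait();
    return 0;
}

The series is the task flow: tasks pushed onto it while it runs simply extend the flow, and the series callback fires only after every task in it has finished.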

2. Structured Concurrency and Task Hiding

  • The system contains five kinds of basic tasks: networking, computation, file IO, timers, and counters.
  • Every task is produced by a task factory, and users organize the concurrency structure through the interfaces, e.g. series, parallels, and DAGs.
  • In most cases, a task produced by a task factory hides several asynchronous steps that the user never has to perceive.
    • For example, a single http request may involve many asynchronous steps (DNS resolution, redirects), yet to the user it is just one networking task.
    • File sorting looks like a single algorithm, but it actually involves a complex interaction between file IO and CPU computation.
    • If business logic is pictured as building a circuit out of ready-made electronic components, each component may itself be a complex circuit inside.
    • The task-hiding mechanism greatly reduces the number of tasks users need to create and the depth of their callbacks.
  • Every task runs in some serial flow (series) and shares the series context, which makes passing data between asynchronous tasks simple (see the sketch after this list).
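To make the series-parallel structure concrete, the following sketch (again not from the article, with placeholder URLs) runs two HTTP tasks in parallel, each in its own series, and then a one-second timer task after the parallel work, all inside one outer series:

#include <cstdio>
#include <workflow/WFTaskFactory.h>
#include <workflow/WFFacilities.h>
#include <workflow/Workflow.h>

int main() {
    WFFacilities::WaitGroup wait_group(1);

    // Each branch of the parallel work is itself a series and could keep growing.
    ParallelWork* parallel = Workflow::create_parallel_work(
        [](const ParallelWork*) { printf("all branches finished\n"); });

    const char* urls[] = {"http://www.sogou.com", "http://www.sogou.com/404"};
    for (const char* url : urls) {
        WFHttpTask* task = WFTaskFactory::create_http_task(url, 2, 1, nullptr);
        parallel->add_series(Workflow::create_series_work(task, nullptr));
    }

    // The parallel work is just a task: put it in an outer series, followed by a 1s timer.
    SeriesWork* series = Workflow::create_series_work(parallel,
        [&wait_group](const SeriesWork*) { wait_group.done(); });
    series->push_back(WFTaskFactory::create_timer_task(1000000, nullptr));
    series->start();

    wait_group.wait();
    return 0;
}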

3. Callbacks and Memory Recycling

  • Every call is executed asynchronously; there is essentially no operation that occupies a thread while waiting.
  • The callback mechanism is explicit: users know they are writing asynchronous programs.
  • A unified object lifetime mechanism greatly simplifies memory management in asynchronous programs.
    • Any task created by the framework lives from its creation until its callback finishes, so there is no risk of leaks. If a task is created but will not be run, it must be deleted via the dismiss() interface.
    • The data inside a task, such as the resp of a network request, is recycled together with the task; users can move out whatever they need with std::move() (see the sketch after this list).
    • The project does not use any smart pointers to manage memory, which keeps the code clean to read.
  • User-level derivation is avoided as much as possible; user behavior is wrapped with std::function, including:
    • The callback of any task.
    • The process of any server, in line with the FaaS (Function as a Service) idea.
    • The implementation of an algorithm is, simply put, also a std::function.
    • Dig deeper, though, and you will find that everything can be derived from.
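A minimal sketch of these lifetime rules (placeholder URL, not from the article): data from the resp must be copied or moved out inside the callback, and a task that is created but never started must be dismissed:

#include <cstdio>
#include <string>
#include <workflow/WFTaskFactory.h>
#include <workflow/HttpUtil.h>
#include <workflow/WFFacilities.h>

int main() {
    WFFacilities::WaitGroup wait_group(1);

    WFHttpTask* task = WFTaskFactory::create_http_task(
        "http://www.sogou.com", 2, 1, [&wait_group](WFHttpTask* task) {
            // The task and its req/resp are recycled once this callback returns,
            // so take out (copy or std::move) whatever must outlive the task here.
            std::string body;
            if (task->get_state() == WFT_STATE_SUCCESS)
                body = protocol::HttpUtil::decode_chunked_body(task->get_resp());
            printf("body size: %zu\n", body.size());
            wait_group.done();
        });
    task->start();

    // A task that is created but never started must be deleted via dismiss().
    WFHttpTask* unused = WFTaskFactory::create_http_task("http://www.sogou.com", 0, 0, nullptr);
    unused->dismiss();

    wait_group.wait();
    return 0;
}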

4. AI Service Deployment Based on Workflow

This part mainly follows the project mortred_model_server. The core code for Workflow-based service deployment lives in two files, src/server/abstract_server.h and src/server/base_server_impl.h:

/************************************************
* Copyright MaybeShewill-CV. All Rights Reserved.
* Author: MaybeShewill-CV
* File: base_server.h
* Date: 22-6-21
************************************************/

#ifndef MORTRED_MODEL_SERVER_BASESERVER_H
#define MORTRED_MODEL_SERVER_BASESERVER_H

#include <toml/toml.hpp>
#include <workflow/WFTask.h>
#include <workflow/WFHttpServer.h>

#include "common/status_code.h"

namespace jinq {
namespace server {
class BaseAiServer {
public:
    /***
    *
    */
    virtual ~BaseAiServer() = default;

    /***
     * Constructor
     * @param config
     */
    BaseAiServer() = default;

    /***
     *
     * @param cfg
     * @return
     */
    virtual jinq::common::StatusCode init(const decltype(toml::parse(""))& cfg) = 0;

    /***
     *
     * @param input
     * @param output
     * @return
     */
    virtual void serve_process(WFHttpTask* task) = 0;

    /***
     *
     * @return
     */
    virtual bool is_successfully_initialized() const = 0;

    /***
     *
     * @param port
     * @return
     */
    inline int start(unsigned short port) {
        return _m_server->start(port);
    };

    /***
     *
     * @param host
     * @param port
     * @return
     */
    inline int start(const char *host, unsigned short port) {
        return _m_server->start(host, port);
    };

    /***
     *
     */
    inline void stop() {
        return _m_server->stop();
    };

    /***
     *
     */
    inline void shutdown() {
        _m_server->shutdown();
    };

    /***
     *
     */
    inline void wait_finish() {
        _m_server->wait_finish();
    }

protected:
    std::unique_ptr<WFHttpServer> _m_server;
};
}
}

#endif //MORTRED_MODEL_SERVER_BASESERVER_H

The core interface is the pure virtual function init, which initializes the service. It is declared pure virtual because different services correspond to different Workers: an image classification task corresponds to a Classifier, an object detection task corresponds to a Detector, and each kind of task needs to bind and initialize its own instances during initialization.

The corresponding implementation code is as follows:

/************************************************
* Copyright MaybeShewill-CV. All Rights Reserved.
* Author: MaybeShewill-CV
* File: base_server_impl.h
* Date: 22-6-30
************************************************/

#ifndef MORTRED_MODEL_SERVER_BASE_SERVER_IMPL_H
#define MORTRED_MODEL_SERVER_BASE_SERVER_IMPL_H

#include "glog/logging.h"
#include "toml/toml.hpp"
#include "stl_container/concurrentqueue.h"
#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"
#include "workflow/HttpMessage.h"
#include "workflow/HttpUtil.h"
#include "workflow/WFTaskFactory.h"
#include "workflow/WFHttpServer.h"
#include "workflow/Workflow.h"

#include "common/md5.h"
#include "common/base64.h"
#include "common/cv_utils.h"
#include "common/status_code.h"
#include "common/time_stamp.h"
#include "common/file_path_util.h"
#include "models/model_io_define.h"

namespace jinq {
namespace server {
using jinq::common::Base64;
using jinq::common::CvUtils;
using jinq::common::FilePathUtil;
using jinq::common::Md5;
using jinq::common::StatusCode;
using jinq::common::Timestamp;

template<typename WORKER, typename MODEL_OUTPUT>
class BaseAiServerImpl {
public:
    /***
    *
    */
    virtual ~BaseAiServerImpl() = default;

    /***
     *
     * @param config
     */
    BaseAiServerImpl() = default;

    /***
    *
    * @param transformer
    */
    BaseAiServerImpl(const BaseAiServerImpl& BaseAiServerImpl) = default;

    /***
     *
     * @param transformer
     * @return
     */
    BaseAiServerImpl& operator=(const BaseAiServerImpl& transformer) = default;

    /***
     *
     * @param cfg
     * @return
     */
    virtual StatusCode init(const decltype(toml::parse(""))& cfg) = 0;

    /***
    *
    * @param task
    */
    virtual void serve_process(WFHttpTask* task);

    /***
     *
     * @return
     */
    virtual bool is_successfully_initialized() const {
        return _m_successfully_initialized;
    };

public:
    int max_connection_nums = 200;
    int peer_resp_timeout = 15 * 1000;
    int compute_threads = -1;
    int handler_threads = 50;
    size_t request_size_limit = -1;

protected:
    // init flag
    bool _m_successfully_initialized = false;
    // task count
    std::atomic<size_t> _m_received_jobs{0};
    std::atomic<size_t> _m_finished_jobs{0};
    std::atomic<size_t> _m_waiting_jobs{0};
    // worker queue
    moodycamel::ConcurrentQueue<WORKER> _m_working_queue;
    // model run timeout
    int _m_model_run_timeout = 500; // ms
    // server uri
    std::string _m_server_uri;

protected:
    struct seriex_ctx {
        protocol::HttpResponse* response = nullptr;
        StatusCode model_run_status = StatusCode::OK;
        std::string task_id;
        std::string task_received_ts;
        std::string task_finished_ts;
        bool is_task_req_valid = false;
        double worker_run_time_consuming = 0; // ms
        double find_worker_time_consuming = 0; // ms
        MODEL_OUTPUT model_output;
    };

    struct cls_request {
        std::string image_content;
        std::string task_id;
        bool is_valid = true;
    };

protected:
    /***
     *
     * @param req_body
     * @return
     */
     virtual cls_request parse_task_request(const std::string& req_body) {

        rapidjson::Document doc;
        doc.Parse(req_body.c_str());
        cls_request req{};

        if (doc.HasParseError() || doc.IsNull() || doc.ObjectEmpty()) {
            req.image_content = "";
            req.is_valid = false;
        } else {
            CHECK_EQ(doc.IsObject(), true);
            if (!doc.HasMember("img_data") || !doc["img_data"].IsString()) {
                req.image_content = "";
                req.is_valid = false;
            } else {
                req.image_content = doc["img_data"].GetString();
                req.is_valid = true;
            }

            if (!doc.HasMember("req_id") || !doc["req_id"].IsString()) {
                req.task_id = "";
                req.is_valid = false;
            } else {
                req.task_id = doc["req_id"].GetString();
            }
        }

        return req;
    };

    /***
     *
     * @param task_id
     * @param status
     * @param model_output
     * @return
     */
    virtual std::string make_response_body(
        const std::string& task_id,
        const StatusCode& status,
        const MODEL_OUTPUT& model_output) = 0;

    /***
     *
     * @param req
     * @param ctx
     */
    virtual void do_work(const cls_request& req, seriex_ctx* ctx);

    /***
     *
     * @param task
     */
    virtual void do_work_cb(const WFGoTask* task);
};

/*********** Public Func Sets **************/

/***
 *
 * @tparam WORKER
 * @tparam MODEL_INPUT
 * @tparam MODEL_OUTPUT
 * @param task
 */
template<typename WORKER, typename MODEL_OUTPUT>
void BaseAiServerImpl<WORKER, MODEL_OUTPUT>::serve_process(WFHttpTask* task) {
    // welcome message
    if (strcmp(task->get_req()->get_request_uri(), "/welcome") == 0) {
        task->get_resp()->append_output_body("<html>Welcome to jinq ai server</html>");
        return;
    }
    // hello world message
    else if (strcmp(task->get_req()->get_request_uri(), "/hello_world") == 0) {
        task->get_resp()->append_output_body("<html>Hello World !!!</html>");
        return;
    }
    // model service
    else if (strcmp(task->get_req()->get_request_uri(), _m_server_uri.c_str()) == 0) {
        // parse request body
        auto* req = task->get_req();
        auto* resp = task->get_resp();
        auto cls_task_req = parse_task_request(protocol::HttpUtil::decode_chunked_body(req));
        _m_waiting_jobs++;
        _m_received_jobs++;
        // init series work
        auto* series = series_of(task);
        auto* ctx = new seriex_ctx;
        ctx->response = resp;
        series->set_context(ctx);
        // do model work
        auto&& go_proc = std::bind(&BaseAiServerImpl<WORKER, MODEL_OUTPUT>::do_work, this, std::placeholders::_1, std::placeholders::_2);
        WFGoTask* serve_task = nullptr;
        if (_m_model_run_timeout <= 0) {
            serve_task = WFTaskFactory::create_go_task(_m_server_uri, go_proc, cls_task_req, ctx);
        } else {
            serve_task = WFTaskFactory::create_timedgo_task(
                0, _m_model_run_timeout * 1e6, _m_server_uri, go_proc, cls_task_req, ctx);
        }
        auto&& go_proc_cb = std::bind(&BaseAiServerImpl<WORKER, MODEL_OUTPUT>::do_work_cb, this, serve_task);
        serve_task->set_callback(go_proc_cb);
        *series << serve_task;
        WFCounterTask* counter = WFTaskFactory::create_counter_task("release_ctx", 1, [](const WFCounterTask* task){
            delete (seriex_ctx*)series_of(task)->get_context();
        });
        *series << counter;
        return;
    }
    // not found valid url
    else {
        task->get_resp()->append_output_body("<html>404 Not Found</html>");
        return;
    }
}

/***
 *
 * @tparam WORKER
 * @tparam MODEL_INPUT
 * @tparam MODEL_OUTPUT
 * @param req
 * @param ctx
 */
template<typename WORKER, typename MODEL_OUTPUT>
void BaseAiServerImpl<WORKER, MODEL_OUTPUT>::do_work(
    const BaseAiServerImpl::cls_request& req,
    BaseAiServerImpl::seriex_ctx* ctx) {
    // get model worker
    WORKER worker;
    auto find_worker_start_ts = Timestamp::now();
    while (!_m_working_queue.try_dequeue(worker)) {}
    ctx->find_worker_time_consuming = (Timestamp::now() - find_worker_start_ts) * 1000;

    // get task receive timestamp
    ctx->task_id = req.task_id;
    ctx->is_task_req_valid = req.is_valid;
    auto task_receive_ts = Timestamp::now();
    ctx->task_received_ts = task_receive_ts.to_format_str();

    // construct model input
    models::io_define::common_io::base64_input model_input{req.image_content};

    // do model inference
    StatusCode status;
    if (req.is_valid) {
        status = worker->run(model_input, ctx->model_output);
        if (status != StatusCode::OK) {
            LOG(ERROR) << "worker run failed";
        }
    } else {
        status = StatusCode::MODEL_EMPTY_INPUT_IMAGE;
    }
    ctx->model_run_status = status;

    // restore worker queue
    while (!_m_working_queue.enqueue(std::move(worker))) {}

    // update ctx
    auto task_finish_ts = Timestamp::now();
    ctx->task_finished_ts = task_finish_ts.to_format_str();
    ctx->worker_run_time_consuming = (task_finish_ts - task_receive_ts) * 1000;
    WFTaskFactory::count_by_name("release_ctx");
}

/***
 *
 * @tparam WORKER
 * @tparam MODEL_INPUT
 * @tparam MODEL_OUTPUT
 * @param task
 */
template<typename WORKER, typename MODEL_OUTPUT>
void BaseAiServerImpl<WORKER, MODEL_OUTPUT>::do_work_cb(const WFGoTask* task) {
    auto state = task->get_state();
    auto* ctx = (seriex_ctx*)series_of(task)->get_context();

    // fill response
    StatusCode status;

    if (state != WFT_STATE_SUCCESS) {
        LOG(ERROR) << "task: " << ctx->task_id << " model run timeout";
        status = StatusCode::MODEL_RUN_TIMEOUT;
    } else {
        status = ctx->model_run_status;
    }

    std::string task_id = ctx->is_task_req_valid ? ctx->task_id : "";
    std::string response_body = make_response_body(task_id, status, ctx->model_output);
    ctx->response->append_output_body(std::move(response_body));

    // update task count
    _m_finished_jobs++;
    _m_waiting_jobs--;

    // output log info
    LOG(INFO) << "task id: " << task_id
              << " received at: " << ctx->task_received_ts
              << " finished at: " << ctx->task_finished_ts
              << " elapse: " << ctx->worker_run_time_consuming << " ms"
              << " find work elapse: " << ctx->find_worker_time_consuming << " ms"
              << " received jobs: " << _m_received_jobs
              << " waiting jobs: " << _m_waiting_jobs
              << " finished jobs: " << _m_finished_jobs
              << " worker queue size: " << _m_working_queue.size_approx();
    // WFTaskFactory::count_by_name("release_ctx");
}
}
}


#endif //MORTRED_MODEL_SERVER_BASE_SERVER_IMPL_H

The core code here is serve_process. It first routes on the request URI to dispatch to the different endpoints. For the model-service path it performs the following steps in order:

  • (1) Parse the request body into the corresponding request structure (a JSON object whose img_data field carries the base64-encoded image and whose req_id field carries the request id);
  • (2) Set up the context of the serial task flow (series), which stores the input and output;
  • (3) Create a go task that runs do_work, bind do_work_cb as its callback, and append it to the series;
  • (4) Create a counter task that releases the context resources, and append it to the series;

The do_work interface performs the following steps in order:

  • (1) Take a worker out of the worker queue;
  • (2) Use that worker to process the current task;
  • (3) Put the worker back into the worker queue;
  • (4) Increment the counter (WFTaskFactory::count_by_name("release_ctx"));

It is worth pausing on why the counter task can release the context resources. Each request creates its own counter task named "release_ctx" with a target value of 1 and appends it to the tail of its series, so the count starts at 0. When do_work finishes, it calls WFTaskFactory::count_by_name("release_ctx"), which raises the count to 1; reaching the target value fires the counter's callback, which deletes the series context allocated in serve_process. Named counters can be counted even before they start running, so the call made inside do_work works although the counter task behind it in the series has not started yet.

Finally, let us look at a classification service to see how everything is wired together, taking src/server/classification/densenet_server.cpp as an example:

/************************************************
* Copyright MaybeShewill-CV. All Rights Reserved.
* Author: MaybeShewill-CV
* File: densenet_server.cpp
* Date: 22-7-1
************************************************/

#include "densenet_server.h"

#include "glog/logging.h"
#include "toml/toml.hpp"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"
#include "workflow/WFTaskFactory.h"
#include "workflow/WFHttpServer.h"

#include "common/status_code.h"
#include "common/file_path_util.h"
#include "models/model_io_define.h"
#include "server/base_server_impl.h"
#include "factory/classification_task.h"

namespace jinq {
namespace server {

using jinq::common::FilePathUtil;
using jinq::common::StatusCode;
using jinq::server::BaseAiServerImpl;

namespace classification {

using jinq::factory::classification::create_densenet_classifier;
using jinq::models::io_define::common_io::base64_input;
using jinq::models::io_define::classification::std_classification_output;
using DenseNetPtr = decltype(create_densenet_classifier<base64_input, std_classification_output>(""));

/************ Impl Declaration ************/

class DenseNetServer::Impl : public BaseAiServerImpl<DenseNetPtr, std_classification_output> {
public:
    /***
    *
    * @param cfg_file_path
    * @return
    */
    StatusCode init(const decltype(toml::parse(""))& config) override;

protected:
    /***
     *
     * @param task_id
     * @param status
     * @param model_output
     * @return
     */
    std::string make_response_body(
        const std::string& task_id,
        const StatusCode& status,
        const std_classification_output& model_output) override;
};

/************ Impl Implementation ************/

/***
 *
 * @param config
 * @return
 */
StatusCode DenseNetServer::Impl::init(const decltype(toml::parse("")) &config) {
    // init working queue
    auto server_section = config.at("DENSENET_CLASSIFICATION_SERVER");
    auto worker_nums = static_cast<int>(server_section.at("worker_nums").as_integer());
    auto model_section = config.at("DENSENET");
    auto model_cfg_path = model_section.at("model_config_file_path").as_string();

    if (!FilePathUtil::is_file_exist(model_cfg_path)) {
        LOG(FATAL) << "densenet model config file not exist: " << model_cfg_path;
        _m_successfully_initialized = false;
        return StatusCode::SERVER_INIT_FAILED;
    }

    auto model_cfg = toml::parse(model_cfg_path);
    for (int index = 0; index < worker_nums; ++index) {
        auto worker = create_densenet_classifier<base64_input, std_classification_output>(
                          "worker_" + std::to_string(index + 1));
        if (!worker->is_successfully_initialized()) {
            if (worker->init(model_cfg) != StatusCode::OK) {
                _m_successfully_initialized = false;
                return StatusCode::SERVER_INIT_FAILED;
            }
        }

        _m_working_queue.enqueue(std::move(worker));
    }

    // init worker run timeout
    if (!server_section.contains("model_run_timeout")) {
        _m_model_run_timeout = 500; // ms
    } else {
        _m_model_run_timeout = static_cast<int>(server_section.at("model_run_timeout").as_integer());
    }

    // init server uri
    if (!server_section.contains("server_url")) {
        LOG(ERROR) << "missing server uri field";
        _m_successfully_initialized = false;
        return StatusCode::SERVER_INIT_FAILED;
    } else {
        _m_server_uri = server_section.at("server_url").as_string();
    }

    // init server params
    max_connection_nums = static_cast<int>(server_section.at("max_connections").as_integer());
    peer_resp_timeout = static_cast<int>(server_section.at("peer_resp_timeout").as_integer()) * 1000;
    compute_threads = static_cast<int>(server_section.at("compute_threads").as_integer());
    handler_threads = static_cast<int>(server_section.at("handler_threads").as_integer());
    request_size_limit = static_cast<size_t>(server_section.at("request_size_limit").as_integer());

    _m_successfully_initialized = true;
    LOG(INFO) << "densenet classification server init successfully";
    return StatusCode::OK;
}

/***
 *
 * @param task_id
 * @param status
 * @param model_output
 * @return
 */
std::string DenseNetServer::Impl::make_response_body(
    const std::string& task_id,
    const StatusCode& status,
    const std_classification_output& model_output) {
    int code = static_cast<int>(status);
    std::string msg = status == StatusCode::OK ? "success" : jinq::common::error_code_to_str(code);
    int cls_id = -1;
    float scores = -1.0;

    if (status == StatusCode::OK) {
        cls_id = model_output.class_id;
        scores = model_output.scores[cls_id];
    }

    rapidjson::StringBuffer buf;
    rapidjson::Writer<rapidjson::StringBuffer> writer(buf);
    writer.StartObject();
    // write req id
    writer.Key("req_id");
    writer.String(task_id.c_str());
    // write code
    writer.Key("code");
    writer.Int(code);
    // write msg
    writer.Key("msg");
    writer.String(msg.c_str());
    // write class result
    writer.Key("data");
    writer.StartObject();
    writer.Key("class_id");
    writer.Int(cls_id);
    writer.Key("scores");
    writer.Double(scores);
    writer.EndObject();
    writer.EndObject();

    return buf.GetString();
}

/***
 *
 */
DenseNetServer::DenseNetServer() {
    _m_impl = std::make_unique<Impl>();
}

/***
 *
 */
DenseNetServer::~DenseNetServer() = default;

/***
 *
 * @param cfg
 * @return
 */
jinq::common::StatusCode DenseNetServer::init(const decltype(toml::parse("")) &config) {
    // init impl
    auto status = _m_impl->init(config);
    if (status != StatusCode::OK) {
        LOG(INFO) << "init densenet classification server failed";
        return status;
    }

    // init server
    WFGlobalSettings settings = GLOBAL_SETTINGS_DEFAULT;
    settings.compute_threads = _m_impl->compute_threads;
    settings.handler_threads = _m_impl->handler_threads;
    WORKFLOW_library_init(&settings);

    WFServerParams server_params = SERVER_PARAMS_DEFAULT;
    server_params.max_connections = _m_impl->max_connection_nums;
    server_params.peer_response_timeout = _m_impl->peer_resp_timeout;
    server_params.request_size_limit = _m_impl->request_size_limit * 1024 * 1024;

    auto&& proc = [&](auto arg) { return this->_m_impl->serve_process(arg); };
    _m_server = std::make_unique<WFHttpServer>(&server_params, proc);

    return StatusCode::OK;
}

/***
 *
 * @param task
 */
void DenseNetServer::serve_process(WFHttpTask* task) {
    return _m_impl->serve_process(task);
}

/***
 *
 * @return
 */
bool DenseNetServer::is_successfully_initialized() const {
    return _m_impl->is_successfully_initialized();
}
}
}
}

DenseNetServer::init does the following:

  • (1) Call the implementation's init interface (_m_impl->init(config));
  • (2) Set the global server settings, including the number of compute threads and handler threads;
  • (3) Set the server parameters, including the maximum number of connections, the peer response timeout, and the request size limit;
  • (4) Bind the server's process function (serve_process);

DenseNetServer::Impl::init does the following:

  • (1) Initialize every worker and add it to the worker queue;
  • (2) Parse the URI that the model service is served on;
  • (3) Initialize the server parameters: maximum connections, peer response timeout, compute threads, handler threads, and request size limit;
  • (4) Set the successfully-initialized flag;
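To round off the walkthrough, here is a hypothetical driver sketch based only on the interfaces shown above; the header path, config path, and port are placeholders and are not taken from the project:

#include <toml/toml.hpp>
#include "server/classification/densenet_server.h"   // assumed include path

int main() {
    jinq::server::classification::DenseNetServer server;

    // Parse the server configuration (placeholder path); init() builds the worker
    // queue, the global Workflow settings, and the underlying WFHttpServer.
    auto cfg = toml::parse("conf/densenet_server.toml");
    if (server.init(cfg) != jinq::common::StatusCode::OK || !server.is_successfully_initialized())
        return -1;

    // Start listening (placeholder port) and block until the server is shut down.
    if (server.start(8094) == 0)
        server.wait_finish();
    return 0;
}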

5. References

