commit 37865d041f18999bec0400b5dd09e7ab48b3ad85 Author: 凝望 <2050965275@qq.com> Date: Mon Oct 13 18:34:48 2025 +0800 update diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4233025 --- /dev/null +++ b/.gitignore @@ -0,0 +1,173 @@ +# ================================ +# 构建系统和编译产物 +# ================================ +# CMake 构建目录 +*/build/ +CMakeCache.txt +CMakeFiles/ +cmake_install.cmake +Makefile +*.cmake +CMakeOutput.log +CMakeRuleHashes.txt +cmake.check_cache +CMakeDirectoryInformation.cmake +CMakeScratch/ +CMakeScripts/ +*.cmake.* +pkgRedirects/ +progress.marks +TargetDirectories.txt + +# 编译产物 +*.o +*.o.d +*.obj +*.exe +*.out +*.app + +# 协议缓冲区生成的文件 +proto/generated/ + +# ODB 生成的数据库文件 +*-odb.cxx +*-odb.hxx +*-odb.ixx +*.sql + +# ================================ +# 可执行文件和二进制文件 +# ================================ +# 服务端可执行文件 +*/file_server +*/file_client +*/friend_server +*/friend_client +*/gateway_server +*/message_server +*/speech_server +*/speech_client +*/transmite_server +*/transmite_client +*/trans_user_client +*/user_server +*/user_client + +# 测试可执行文件 +*/test/*/main +*/test/mysql_test/main +*/test/es_test/main +*/test/redis_test/main + +# 其他二进制文件 +*/nc +*/make_file_download +*/base_download_file1 +*/file_download_file2 + +# ================================ +# 运行时数据和日志 +# ================================ +# 中间件数据 +middle/data/ +middle/elasticsearch/ +middle/etcd/ +middle/logs/ +middle/mysql/ +middle/rabbitmq/ +middle/redis/ + +# 文件服务数据 +file/build/data/ + +# 语音服务数据 +speech/build/16k.pcm +speech/test/16k.pcm + +# 日志文件 +*.log +*.log.* +log/ +logs/ + +# ================================ +# 依赖库文件 +# ================================ +# 本地依赖库 +*/depends/ +*.so +*.so.* +*.dylib +*.dll + +# ================================ +# 临时文件和缓存 +# ================================ +*.tmp +*.temp +.tmp/ +.cache/ +*.swp +*.swo +*~ + +# CMake 临时文件 +*/CMakeFiles/*/CompilerIdC/a.out +*/CMakeFiles/*/CompilerIdCXX/a.out +*/CMakeFiles/*/CompilerIdC/tmp/ +*/CMakeFiles/*/CompilerIdCXX/tmp/ + +# ================================ +# 系统文件 +# ================================ +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db +Desktop.ini + +# ================================ +# IDE 和编辑器文件 +# ================================ +# VS Code +.vscode/ +*.code-workspace + +# Vim +*.swp +*.swo +*~ + +# ================================ +# 协议缓冲区源文件(保留 .proto,忽略生成的 .pb.*) +# ================================ +# 注意:保留 .proto 文件,但忽略生成的文件 +*.pb.cc +*.pb.h + +# ================================ +# 项目特定文件 +# ================================ +# 树状结构输出文件 +tree.txt + +# 下载的文件 +base_download_file1 +file_download_file2 + +# 测试目录中的构建文件 +*/test/*/Makefile +*/test/*/*.o +*/test/*/*.o.d + +# ================================ +# Docker 相关(可选) +# ================================ +# 如果您使用 Docker,可以取消注释以下行 +# .dockerignore +# !.dockerignore +# docker-compose.override.yml \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100755 index 0000000..f7d466f --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,10 @@ +cmake_minimum_required(VERSION 3.1.3) +project(all-test) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/message) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/user) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/file) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/speech) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/transmite) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/friend) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/gateway) +set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}) \ No newline at end of file diff --git a/common/asr.hpp b/common/asr.hpp new file mode 100644 index 0000000..3e23cef --- /dev/null +++ b/common/asr.hpp @@ -0,0 +1,25 @@ +#pragma once +#include "aip-cpp-sdk/speech.h" +#include "logger.hpp" + +namespace bite_im{ +class ASRClient { + public: + using ptr = std::shared_ptr; + ASRClient(const std::string &app_id, + const std::string &api_key, + const std::string &secret_key): + _client(app_id, api_key, secret_key) {} + std::string recognize(const std::string &speech_data, std::string &err){ + Json::Value result = _client.recognize(speech_data, "pcm", 16000, aip::null); + if (result["err_no"].asInt() != 0) { + LOG_ERROR("语音识别失败:{}", result["err_msg"].asString()); + err = result["err_msg"].asString(); + return std::string(); + } + return result["result"][0].asString(); + } + private: + aip::Speech _client; +}; +} \ No newline at end of file diff --git a/common/channel.hpp b/common/channel.hpp new file mode 100644 index 0000000..7c5ed4a --- /dev/null +++ b/common/channel.hpp @@ -0,0 +1,151 @@ +#pragma once +#include +#include +#include +#include +#include +#include "logger.hpp" + +namespace bite_im{ +//1. 封装单个服务的信道管理类: +class ServiceChannel { + public: + using ptr = std::shared_ptr; + using ChannelPtr = std::shared_ptr; + ServiceChannel(const std::string &name): + _service_name(name), _index(0){} + //服务上线了一个节点,则调用append新增信道 + void append(const std::string &host) { + auto channel = std::make_shared(); + brpc::ChannelOptions options; + options.connect_timeout_ms = -1; + options.timeout_ms = -1; + options.max_retry = 3; + options.protocol = "baidu_std"; + int ret = channel->Init(host.c_str(), &options); + if (ret == -1) { + LOG_ERROR("初始化{}-{}信道失败!", _service_name, host); + return; + } + std::unique_lock lock(_mutex); + _hosts.insert(std::make_pair(host, channel)); + _channels.push_back(channel); + } + //服务下线了一个节点,则调用remove释放信道 + void remove(const std::string &host) { + std::unique_lock lock(_mutex); + auto it = _hosts.find(host); + if (it == _hosts.end()) { + LOG_WARN("{}-{}节点删除信道时,没有找到信道信息!", _service_name, host); + return; + } + for (auto vit = _channels.begin(); vit != _channels.end(); ++vit) { + if (*vit == it->second) { + _channels.erase(vit); + break; + } + } + _hosts.erase(it); + } + //通过RR轮转策略,获取一个Channel用于发起对应服务的Rpc调用 + ChannelPtr choose() { + std::unique_lock lock(_mutex); + if (_channels.size() == 0) { + LOG_ERROR("当前没有能够提供 {} 服务的节点!", _service_name); + return ChannelPtr(); + } + int32_t idx = _index++ % _channels.size(); + return _channels[idx]; + } + private: + std::mutex _mutex; + int32_t _index; //当前轮转下标计数器 + std::string _service_name;//服务名称 + std::vector _channels; //当前服务对应的信道集合 + std::unordered_map _hosts; //主机地址与信道映射关系 +}; + +//总体的服务信道管理类 +class ServiceManager { + public: + using ptr = std::shared_ptr; + ServiceManager() {} + //获取指定服务的节点信道 + ServiceChannel::ChannelPtr choose(const std::string &service_name) { + std::unique_lock lock(_mutex); + auto sit = _services.find(service_name); + if (sit == _services.end()) { + LOG_ERROR("当前没有能够提供 {} 服务的节点!", service_name); + return ServiceChannel::ChannelPtr(); + } + return sit->second->choose(); + } + + //先声明,我关注哪些服务的上下线,不关心的就不需要管理了 + void declared(const std::string &service_name) { + std::unique_lock lock(_mutex); + _follow_services.insert(service_name); + } + + //服务上线时调用的回调接口,将服务节点管理起来 + void onServiceOnline(const std::string &service_instance, const std::string &host) { + std::string service_name = getServiceName(service_instance); + ServiceChannel::ptr service; + { + std::unique_lock lock(_mutex); + auto fit = _follow_services.find(service_name); + if (fit == _follow_services.end()) { + LOG_DEBUG("{}-{} 服务上线了,但是当前并不关心!", service_name, host); + return; + } + //先获取管理对象,没有则创建,有则添加节点 + auto sit = _services.find(service_name); + if (sit == _services.end()) { + service = std::make_shared(service_name); + _services.insert(std::make_pair(service_name, service)); + }else { + service = sit->second; + } + } + if (!service) { + LOG_ERROR("新增 {} 服务管理节点失败!", service_name); + return ; + } + service->append(host); + LOG_DEBUG("{}-{} 服务上线新节点,进行添加管理!", service_name, host); + } + + //服务下线时调用的回调接口,从服务信道管理中,删除指定节点信道 + void onServiceOffline(const std::string &service_instance, const std::string &host) { + std::string service_name = getServiceName(service_instance); + ServiceChannel::ptr service; + { + std::unique_lock lock(_mutex); + auto fit = _follow_services.find(service_name); + if (fit == _follow_services.end()) { + LOG_DEBUG("{}-{} 服务下线了,但是当前并不关心!", service_name, host); + return; + } + //先获取管理对象,没有则创建,有则添加节点 + auto sit = _services.find(service_name); + if (sit == _services.end()) { + LOG_WARN("删除{}服务节点时,没有找到管理对象", service_name); + return; + } + service = sit->second; + } + service->remove(host); + LOG_DEBUG("{}-{} 服务下线节点,进行删除管理!", service_name, host); + } + private: + std::string getServiceName(const std::string &service_instance) { + auto pos = service_instance.find_last_of('/'); + if (pos == std::string::npos) return service_instance; + return service_instance.substr(0, pos); + } + private: + std::mutex _mutex; + std::unordered_set _follow_services; + std::unordered_map _services; +}; +} \ No newline at end of file diff --git a/common/data_es.hpp b/common/data_es.hpp new file mode 100644 index 0000000..99904d4 --- /dev/null +++ b/common/data_es.hpp @@ -0,0 +1,164 @@ +#include "icsearch.hpp" +#include "user.hxx" +#include "message.hxx" + +namespace bite_im { + class ESClientFactory { + public: + static std::shared_ptr create(const std::vector host_list) { + return std::make_shared(host_list); + } + }; + class ESUser { + public: + using ptr = std::shared_ptr; + ESUser(const std::shared_ptr &client): + _es_client(client){} + + bool createIndex() { + bool ret = ESIndex(_es_client, "user") + .append("user_id", "keyword", "standard", true) + .append("nickname") + .append("phone", "keyword", "standard", true) + .append("description", "text", "standard", false) + .append("avatar_id", "keyword", "standard", false) + .create(); + if (ret == false) { + LOG_INFO("用户信息索引创建失败!"); + return false; + } + LOG_INFO("用户信息索引创建成功!"); + return true; + } + + bool appendData(const std::string &uid, + const std::string &phone, + const std::string &nickname, + const std::string &description, + const std::string &avatar_id) { + bool ret = ESInsert(_es_client, "user") + .append("user_id", uid) + .append("nickname", nickname) + .append("phone", phone) + .append("description", description) + .append("avatar_id", avatar_id) + .insert(uid); + if (ret == false) { + LOG_ERROR("用户数据插入/更新失败!"); + return false; + } + LOG_INFO("用户数据新增/更新成功!"); + return true; + } + + std::vector search(const std::string &key, const std::vector &uid_list) { + std::vector res; + Json::Value json_user = ESSearch(_es_client, "user") + .append_should_match("phone.keyword", key) + .append_should_match("user_id.keyword", key) + .append_should_match("nickname", key) + .append_must_not_terms("user_id.keyword", uid_list) + .search(); + if (json_user.isArray() == false) { + LOG_ERROR("用户搜索结果为空,或者结果不是数组类型"); + return res; + } + int sz = json_user.size(); + LOG_DEBUG("检索结果条目数量:{}", sz); + for (int i = 0; i < sz; i++) { + User user; + user.user_id(json_user[i]["_source"]["user_id"].asString()); + user.nickname(json_user[i]["_source"]["nickname"].asString()); + user.description(json_user[i]["_source"]["description"].asString()); + user.phone(json_user[i]["_source"]["phone"].asString()); + user.avatar_id(json_user[i]["_source"]["avatar_id"].asString()); + res.push_back(user); + } + return res; + } + private: + // const std::string _uid_key = "user_id"; + // const std::string _desc_key = "user_id"; + // const std::string _phone_key = "user_id"; + // const std::string _name_key = "user_id"; + // const std::string _avatar_key = "user_id"; + std::shared_ptr _es_client; + }; + + class ESMessage { + public: + using ptr = std::shared_ptr; + ESMessage(const std::shared_ptr &es_client): + _es_client(es_client){} + bool createIndex() { + bool ret = ESIndex(_es_client, "message") + .append("user_id", "keyword", "standard", false) + .append("message_id", "keyword", "standard", false) + .append("create_time", "long", "standard", false) + .append("chat_session_id", "keyword", "standard", true) + .append("content") + .create(); + if (ret == false) { + LOG_INFO("消息信息索引创建失败!"); + return false; + } + LOG_INFO("消息信息索引创建成功!"); + return true; + } + bool appendData(const std::string &user_id, + const std::string &message_id, + const long create_time, + const std::string &chat_session_id, + const std::string &content) { + bool ret = ESInsert(_es_client, "message") + .append("message_id", message_id) + .append("create_time", create_time) + .append("user_id", user_id) + .append("chat_session_id", chat_session_id) + .append("content", content) + .insert(message_id); + if (ret == false) { + LOG_ERROR("消息数据插入/更新失败!"); + return false; + } + LOG_INFO("消息数据新增/更新成功!"); + return true; + } + bool remove(const std::string &mid) { + bool ret = ESRemove(_es_client, "message").remove(mid); + if (ret == false) { + LOG_ERROR("消息数据删除失败!"); + return false; + } + LOG_INFO("消息数据删除成功!"); + return true; + } + std::vector search(const std::string &key, const std::string &ssid) { + std::vector res; + Json::Value json_user = ESSearch(_es_client, "message") + .append_must_term("chat_session_id.keyword", ssid) + .append_must_match("content", key) + .search(); + if (json_user.isArray() == false) { + LOG_ERROR("用户搜索结果为空,或者结果不是数组类型"); + return res; + } + int sz = json_user.size(); + LOG_DEBUG("检索结果条目数量:{}", sz); + for (int i = 0; i < sz; i++) { + bite_im::Message message; + message.user_id(json_user[i]["_source"]["user_id"].asString()); + message.message_id(json_user[i]["_source"]["message_id"].asString()); + boost::posix_time::ptime ctime(boost::posix_time::from_time_t( + json_user[i]["_source"]["create_time"].asInt64())); + message.create_time(ctime); + message.session_id(json_user[i]["_source"]["chat_session_id"].asString()); + message.content(json_user[i]["_source"]["content"].asString()); + res.push_back(message); + } + return res; + } + private: + std::shared_ptr _es_client; + }; +} \ No newline at end of file diff --git a/common/data_redis.hpp b/common/data_redis.hpp new file mode 100644 index 0000000..b82ebef --- /dev/null +++ b/common/data_redis.hpp @@ -0,0 +1,75 @@ +#include +#include + +namespace bite_im { + class RedisClientFactory { + public: + static std::shared_ptr create( + const std::string &host, + int port, + int db, + bool keep_alive) { + sw::redis::ConnectionOptions opts; + opts.host = host; + opts.port = port; + opts.db = db; + opts.keep_alive = keep_alive; + auto res = std::make_shared(opts); + return res; + } + }; + class Session { + public: + using ptr = std::shared_ptr; + Session(const std::shared_ptr &redis_client): + _redis_client(redis_client){} + void append(const std::string &ssid, const std::string &uid) { + _redis_client->set(ssid, uid); + } + void remove(const std::string &ssid) { + _redis_client->del(ssid); + } + sw::redis::OptionalString uid(const std::string &ssid) { + return _redis_client->get(ssid); + } + private: + std::shared_ptr _redis_client; + }; + class Status { + public: + using ptr = std::shared_ptr; + Status(const std::shared_ptr &redis_client): + _redis_client(redis_client){} + void append(const std::string &uid) { + _redis_client->set(uid, ""); + } + void remove(const std::string &uid) { + _redis_client->del(uid); + } + bool exists(const std::string &uid) { + auto res = _redis_client->get(uid); + if (res) return true; + return false; + } + private: + std::shared_ptr _redis_client; + }; + class Codes { + public: + using ptr = std::shared_ptr; + Codes(const std::shared_ptr &redis_client): + _redis_client(redis_client){} + void append(const std::string &cid, const std::string &code, + const std::chrono::milliseconds &t = std::chrono::milliseconds(300000)) { + _redis_client->set(cid, code, t); + } + void remove(const std::string &cid) { + _redis_client->del(cid); + } + sw::redis::OptionalString code(const std::string &cid) { + return _redis_client->get(cid); + } + private: + std::shared_ptr _redis_client; + }; +} \ No newline at end of file diff --git a/common/dms.hpp b/common/dms.hpp new file mode 100644 index 0000000..81b85c6 --- /dev/null +++ b/common/dms.hpp @@ -0,0 +1,46 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "logger.hpp" + +namespace bite_im{ +class DMSClient { + public: + using ptr = std::shared_ptr; + DMSClient(const std::string &access_key_id, + const std::string &access_key_secret) { + AlibabaCloud::InitializeSdk(); + AlibabaCloud::ClientConfiguration configuration( "cn-chengdu" ); + configuration.setConnectTimeout(1500); + configuration.setReadTimeout(4000); + AlibabaCloud::Credentials credential(access_key_id, access_key_secret); + _client = std::make_unique(credential, configuration); + } + ~DMSClient() { AlibabaCloud::ShutdownSdk(); } + bool send(const std::string &phone, const std::string &code) { + AlibabaCloud::CommonRequest request(AlibabaCloud::CommonRequest::RequestPattern::RpcPattern); + request.setHttpMethod(AlibabaCloud::HttpRequest::Method::Post); + request.setDomain("dysmsapi.aliyuncs.com"); + request.setVersion("2017-05-25"); + request.setQueryParameter("Action", "SendSms"); + request.setQueryParameter("SignName", "bitejiuyeke"); + request.setQueryParameter("TemplateCode", "SMS_465324787"); + request.setQueryParameter("PhoneNumbers", phone); + std::string param_code = "{\"code\":\"" + code + "\"}"; + request.setQueryParameter("TemplateParam", param_code); + auto response = _client->commonResponse(request); + if (!response.isSuccess()) { + LOG_ERROR("短信验证码请求失败:{}", response.error().errorMessage()); + return false; + } + return true; + } + private: + std::unique_ptr _client; +}; +} \ No newline at end of file diff --git a/common/etcd.hpp b/common/etcd.hpp new file mode 100644 index 0000000..60efc5d --- /dev/null +++ b/common/etcd.hpp @@ -0,0 +1,83 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "logger.hpp" + +namespace bite_im{ +//服务注册客户端类 +class Registry { + public: + using ptr = std::shared_ptr; + Registry(const std::string &host): + _client(std::make_shared(host)) , + _keep_alive(_client->leasekeepalive(3).get()), + _lease_id(_keep_alive->Lease()){} + ~Registry() { _keep_alive->Cancel(); } + bool registry(const std::string &key, const std::string &val) { + auto resp = _client->put(key, val, _lease_id).get(); + if (resp.is_ok() == false) { + LOG_ERROR("注册数据失败:{}", resp.error_message()); + return false; + } + return true; + } + private: + std::shared_ptr _client; + std::shared_ptr _keep_alive; + uint64_t _lease_id; +}; + +//服务发现客户端类 +class Discovery { + public: + using ptr = std::shared_ptr; + using NotifyCallback = std::function; + Discovery(const std::string &host, + const std::string &basedir, + const NotifyCallback &put_cb, + const NotifyCallback &del_cb): + _client(std::make_shared(host)) , + _put_cb(put_cb), _del_cb(del_cb){ + //先进行服务发现,先获取到当前已有的数据 + auto resp = _client->ls(basedir).get(); + if (resp.is_ok() == false) { + LOG_ERROR("获取服务信息数据失败:{}", resp.error_message()); + } + int sz = resp.keys().size(); + for (int i = 0; i < sz; ++i) { + if (_put_cb) _put_cb(resp.key(i), resp.value(i).as_string()); + } + //然后进行事件监控,监控数据发生的改变并调用回调进行处理 + _watcher = std::make_shared(*_client.get(), basedir, + std::bind(&Discovery::callback, this, std::placeholders::_1), true); + } + ~Discovery() { + _watcher->Cancel(); + } + private: + void callback(const etcd::Response &resp) { + if (resp.is_ok() == false) { + LOG_ERROR("收到一个错误的事件通知: {}", resp.error_message()); + return; + } + for (auto const& ev : resp.events()) { + if (ev.event_type() == etcd::Event::EventType::PUT) { + if (_put_cb) _put_cb(ev.kv().key(), ev.kv().as_string()); + LOG_DEBUG("新增服务:{}-{}", ev.kv().key(), ev.kv().as_string()); + }else if (ev.event_type() == etcd::Event::EventType::DELETE_) { + if (_del_cb) _del_cb(ev.prev_kv().key(), ev.prev_kv().as_string()); + LOG_DEBUG("下线服务:{}-{}", ev.prev_kv().key(), ev.prev_kv().as_string()); + } + } + } + private: + NotifyCallback _put_cb; + NotifyCallback _del_cb; + std::shared_ptr _client; + std::shared_ptr _watcher; +}; +} \ No newline at end of file diff --git a/common/httplib.h b/common/httplib.h new file mode 100644 index 0000000..0e1d522 --- /dev/null +++ b/common/httplib.h @@ -0,0 +1,10508 @@ +// +// httplib.h +// +// Copyright (c) 2025 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +#define CPPHTTPLIB_VERSION "0.20.0" + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND 10000 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 100 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND +#define CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND +#define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_USECOND +#ifdef _WIN32 +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 10000 +#else +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 0 +#endif +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_HEADER_MAX_LENGTH +#define CPPHTTPLIB_HEADER_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT +#define CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#endif + +#ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_RANGE_MAX_COUNT +#define CPPHTTPLIB_RANGE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_TCP_NODELAY +#define CPPHTTPLIB_TCP_NODELAY false +#endif + +#ifndef CPPHTTPLIB_IPV6_V6ONLY +#define CPPHTTPLIB_IPV6_V6ONLY false +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ +#define CPPHTTPLIB_COMPRESSION_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 \ + ? std::thread::hardware_concurrency() - 1 \ + : 0)) +#endif + +#ifndef CPPHTTPLIB_RECV_FLAGS +#define CPPHTTPLIB_RECV_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_SEND_FLAGS +#define CPPHTTPLIB_SEND_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_LISTEN_BACKLOG +#define CPPHTTPLIB_LISTEN_BACKLOG 5 +#endif + +/* + * Headers + */ + +#ifdef _WIN32 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#if _MSC_VER < 1900 +#error Sorry, Visual Studio versions prior to 2015 are not supported +#endif + +#pragma comment(lib, "ws2_32.lib") + +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = long; +#endif +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m) & S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m) & S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +// afunix.h uses types declared in winsock2.h, so has to be included after it. +#include + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +using nfds_t = unsigned long; +using socket_t = SOCKET; +using socklen_t = int; + +#else // not _WIN32 + +#include +#if !defined(_AIX) && !defined(__MVS__) +#include +#endif +#ifdef __MVS__ +#include +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif +#endif +#include +#include +#include +#ifdef __linux__ +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +using socket_t = int; +#ifndef INVALID_SOCKET +#define INVALID_SOCKET (-1) +#endif +#endif //_WIN32 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +#include + +// these are defined in wincrypt.h and it breaks compilation if BoringSSL is +// used +#undef X509_NAME +#undef X509_CERT_PAIR +#undef X509_EXTENSIONS +#undef PKCS7_SIGNER_INFO + +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#include +#if TARGET_OS_OSX +#include +#include +#endif // TARGET_OS_OSX +#endif // _WIN32 + +#include +#include +#include +#include + +#if defined(_WIN32) && defined(OPENSSL_USE_APPLINK) +#include +#endif + +#include +#include + +#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER) +#if OPENSSL_VERSION_NUMBER < 0x1010107f +#error Please use OpenSSL or a current version of BoringSSL +#endif +#define SSL_get1_peer_certificate SSL_get_peer_certificate +#elif OPENSSL_VERSION_NUMBER < 0x30000000L +#error Sorry, OpenSSL versions prior to 3.0.0 are not supported +#endif + +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +#include +#include +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +#include +#endif + +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +/* + * Backport std::make_unique from C++14. + * + * NOTE: This code came up with the following stackoverflow post: + * https://stackoverflow.com/questions/10149840/c-arrays-and-make-unique + * + */ + +template +typename std::enable_if::value, std::unique_ptr>::type +make_unique(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +template +typename std::enable_if::value, std::unique_ptr>::type +make_unique(std::size_t n) { + typedef typename std::remove_extent::type RT; + return std::unique_ptr(new RT[n]); +} + +namespace case_ignore { + +inline unsigned char to_lower(int c) { + const static unsigned char table[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, + 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, + 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, + 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, + 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, + 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 224, 225, 226, + 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, 224, + 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, + 255, + }; + return table[(unsigned char)(char)c]; +} + +inline bool equal(const std::string &a, const std::string &b) { + return a.size() == b.size() && + std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) { + return to_lower(ca) == to_lower(cb); + }); +} + +struct equal_to { + bool operator()(const std::string &a, const std::string &b) const { + return equal(a, b); + } +}; + +struct hash { + size_t operator()(const std::string &key) const { + return hash_core(key.data(), key.size(), 0); + } + + size_t hash_core(const char *s, size_t l, size_t h) const { + return (l == 0) ? h + : hash_core(s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no + // overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(to_lower(*s))); + } +}; + +} // namespace case_ignore + +// This is based on +// "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189". + +struct scope_exit { + explicit scope_exit(std::function &&f) + : exit_function(std::move(f)), execute_on_destruction{true} {} + + scope_exit(scope_exit &&rhs) noexcept + : exit_function(std::move(rhs.exit_function)), + execute_on_destruction{rhs.execute_on_destruction} { + rhs.release(); + } + + ~scope_exit() { + if (execute_on_destruction) { this->exit_function(); } + } + + void release() { this->execute_on_destruction = false; } + +private: + scope_exit(const scope_exit &) = delete; + void operator=(const scope_exit &) = delete; + scope_exit &operator=(scope_exit &&) = delete; + + std::function exit_function; + bool execute_on_destruction; +}; + +} // namespace detail + +enum SSLVerifierResponse { + // no decision has been made, use the built-in certificate verifier + NoDecisionMade, + // connection certificate is verified and accepted + CertificateAccepted, + // connection certificate was processed but is rejected + CertificateRejected +}; + +enum StatusCode { + // Information responses + Continue_100 = 100, + SwitchingProtocol_101 = 101, + Processing_102 = 102, + EarlyHints_103 = 103, + + // Successful responses + OK_200 = 200, + Created_201 = 201, + Accepted_202 = 202, + NonAuthoritativeInformation_203 = 203, + NoContent_204 = 204, + ResetContent_205 = 205, + PartialContent_206 = 206, + MultiStatus_207 = 207, + AlreadyReported_208 = 208, + IMUsed_226 = 226, + + // Redirection messages + MultipleChoices_300 = 300, + MovedPermanently_301 = 301, + Found_302 = 302, + SeeOther_303 = 303, + NotModified_304 = 304, + UseProxy_305 = 305, + unused_306 = 306, + TemporaryRedirect_307 = 307, + PermanentRedirect_308 = 308, + + // Client error responses + BadRequest_400 = 400, + Unauthorized_401 = 401, + PaymentRequired_402 = 402, + Forbidden_403 = 403, + NotFound_404 = 404, + MethodNotAllowed_405 = 405, + NotAcceptable_406 = 406, + ProxyAuthenticationRequired_407 = 407, + RequestTimeout_408 = 408, + Conflict_409 = 409, + Gone_410 = 410, + LengthRequired_411 = 411, + PreconditionFailed_412 = 412, + PayloadTooLarge_413 = 413, + UriTooLong_414 = 414, + UnsupportedMediaType_415 = 415, + RangeNotSatisfiable_416 = 416, + ExpectationFailed_417 = 417, + ImATeapot_418 = 418, + MisdirectedRequest_421 = 421, + UnprocessableContent_422 = 422, + Locked_423 = 423, + FailedDependency_424 = 424, + TooEarly_425 = 425, + UpgradeRequired_426 = 426, + PreconditionRequired_428 = 428, + TooManyRequests_429 = 429, + RequestHeaderFieldsTooLarge_431 = 431, + UnavailableForLegalReasons_451 = 451, + + // Server error responses + InternalServerError_500 = 500, + NotImplemented_501 = 501, + BadGateway_502 = 502, + ServiceUnavailable_503 = 503, + GatewayTimeout_504 = 504, + HttpVersionNotSupported_505 = 505, + VariantAlsoNegotiates_506 = 506, + InsufficientStorage_507 = 507, + LoopDetected_508 = 508, + NotExtended_510 = 510, + NetworkAuthenticationRequired_511 = 511, +}; + +using Headers = + std::unordered_multimap; + +using Params = std::multimap; +using Match = std::smatch; + +using Progress = std::function; + +struct Response; +using ResponseHandler = std::function; + +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +using MultipartFormDataItems = std::vector; +using MultipartFormDataMap = std::multimap; + +class DataSink { +public: + DataSink() : os(&sb_), sb_(*this) {} + + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function is_writable; + std::function done; + std::function done_with_trailer; + std::ostream os; + +private: + class data_sink_streambuf final : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {} + + protected: + std::streamsize xsputn(const char *s, std::streamsize n) override { + sink_.write(s, static_cast(n)); + return n; + } + + private: + DataSink &sink_; + }; + + data_sink_streambuf sb_; +}; + +using ContentProvider = + std::function; + +using ContentProviderWithoutLength = + std::function; + +using ContentProviderResourceReleaser = std::function; + +struct MultipartFormDataProvider { + std::string name; + ContentProviderWithoutLength provider; + std::string filename; + std::string content_type; +}; +using MultipartFormDataProviderItems = std::vector; + +using ContentReceiverWithProgress = + std::function; + +using ContentReceiver = + std::function; + +using MultipartContentHeader = + std::function; + +class ContentReader { +public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader multipart_reader) + : reader_(std::move(reader)), + multipart_reader_(std::move(multipart_reader)) {} + + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return multipart_reader_(std::move(header), std::move(receiver)); + } + + bool operator()(ContentReceiver receiver) const { + return reader_(std::move(receiver)); + } + + Reader reader_; + MultipartReader multipart_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; + +struct Request { + std::string method; + std::string path; + Params params; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + std::string local_addr; + int local_port = -1; + + // for server + std::string version; + std::string target; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + std::unordered_map path_params; + std::function is_connection_closed = []() { return true; }; + + // for client + ResponseHandler response_handler; + ContentReceiverWithProgress content_receiver; + Progress progress; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl = nullptr; +#endif + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, + size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + bool has_param(const std::string &key) const; + std::string get_param_value(const std::string &key, size_t id = 0) const; + size_t get_param_value_count(const std::string &key) const; + + bool is_multipart_form_data() const; + + bool has_file(const std::string &key) const; + MultipartFormData get_file_value(const std::string &key) const; + std::vector get_file_values(const std::string &key) const; + + // private members... + size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; + size_t content_length_ = 0; + ContentProvider content_provider_; + bool is_chunked_content_provider_ = false; + size_t authorization_count_ = 0; + std::chrono::time_point start_time_ = + (std::chrono::steady_clock::time_point::min)(); +}; + +struct Response { + std::string version; + int status = -1; + std::string reason; + Headers headers; + std::string body; + std::string location; // Redirect location + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, + size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + void set_redirect(const std::string &url, int status = StatusCode::Found_302); + void set_content(const char *s, size_t n, const std::string &content_type); + void set_content(const std::string &s, const std::string &content_type); + void set_content(std::string &&s, const std::string &content_type); + + void set_content_provider( + size_t length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_chunked_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_file_content(const std::string &path, + const std::string &content_type); + void set_file_content(const std::string &path); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(content_provider_success_); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + ContentProviderResourceReleaser content_provider_resource_releaser_; + bool is_chunked_content_provider_ = false; + bool content_provider_success_ = false; + std::string file_content_path_; + std::string file_content_content_type_; +}; + +class Stream { +public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool wait_readable() const = 0; + virtual bool wait_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; + virtual socket_t socket() const = 0; + + virtual time_t duration() const = 0; + + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { +public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual bool enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle() {} +}; + +class ThreadPool final : public TaskQueue { +public: + explicit ThreadPool(size_t n, size_t mqr = 0) + : shutdown_(false), max_queued_requests_(mqr) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + bool enqueue(std::function fn) override { + { + std::unique_lock lock(mutex_); + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } + jobs_.push_back(std::move(fn)); + } + + cond_.notify_one(); + return true; + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + +private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ + !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + size_t max_queued_requests_ = 0; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function; + +using SocketOptions = std::function; + +namespace detail { + +bool set_socket_opt_impl(socket_t sock, int level, int optname, + const void *optval, socklen_t optlen); +bool set_socket_opt(socket_t sock, int level, int optname, int opt); +bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, + time_t usec); + +} // namespace detail + +void default_socket_options(socket_t sock); + +const char *status_message(int status); + +std::string get_bearer_token_auth(const Request &req); + +namespace detail { + +class MatcherBase { +public: + virtual ~MatcherBase() = default; + + // Match request path and populate its matches and + virtual bool match(Request &request) const = 0; +}; + +/** + * Captures parameters in request path and stores them in Request::path_params + * + * Capture name is a substring of a pattern from : to /. + * The rest of the pattern is matched against the request path directly + * Parameters are captured starting from the next character after + * the end of the last matched static pattern fragment until the next /. + * + * Example pattern: + * "/path/fragments/:capture/more/fragments/:second_capture" + * Static fragments: + * "/path/fragments/", "more/fragments/" + * + * Given the following request path: + * "/path/fragments/:1/more/fragments/:2" + * the resulting capture will be + * {{"capture", "1"}, {"second_capture", "2"}} + */ +class PathParamsMatcher final : public MatcherBase { +public: + PathParamsMatcher(const std::string &pattern); + + bool match(Request &request) const override; + +private: + // Treat segment separators as the end of path parameter capture + // Does not need to handle query parameters as they are parsed before path + // matching + static constexpr char separator = '/'; + + // Contains static path fragments to match against, excluding the '/' after + // path params + // Fragments are separated by path params + std::vector static_fragments_; + // Stores the names of the path parameters to be used as keys in the + // Request::path_params map + std::vector param_names_; +}; + +/** + * Performs std::regex_match on request path + * and stores the result in Request::matches + * + * Note that regex match is performed directly on the whole request. + * This means that wildcard patterns may match multiple path segments with /: + * "/begin/(.*)/end" will match both "/begin/middle/end" and "/begin/1/2/end". + */ +class RegexMatcher final : public MatcherBase { +public: + RegexMatcher(const std::string &pattern) : regex_(pattern) {} + + bool match(Request &request) const override; + +private: + std::regex regex_; +}; + +ssize_t write_headers(Stream &strm, const Headers &headers); + +} // namespace detail + +class Server { +public: + using Handler = std::function; + + using ExceptionHandler = + std::function; + + enum class HandlerResponse { + Handled, + Unhandled, + }; + using HandlerWithResponse = + std::function; + + using HandlerWithContentReader = std::function; + + using Expect100ContinueHandler = + std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, HandlerWithContentReader handler); + Server &Put(const std::string &pattern, Handler handler); + Server &Put(const std::string &pattern, HandlerWithContentReader handler); + Server &Patch(const std::string &pattern, Handler handler); + Server &Patch(const std::string &pattern, HandlerWithContentReader handler); + Server &Delete(const std::string &pattern, Handler handler); + Server &Delete(const std::string &pattern, HandlerWithContentReader handler); + Server &Options(const std::string &pattern, Handler handler); + + bool set_base_dir(const std::string &dir, + const std::string &mount_point = std::string()); + bool set_mount_point(const std::string &mount_point, const std::string &dir, + Headers headers = Headers()); + bool remove_mount_point(const std::string &mount_point); + Server &set_file_extension_and_mimetype_mapping(const std::string &ext, + const std::string &mime); + Server &set_default_file_mimetype(const std::string &mime); + Server &set_file_request_handler(Handler handler); + + template + Server &set_error_handler(ErrorHandlerFunc &&handler) { + return set_error_handler_core( + std::forward(handler), + std::is_convertible{}); + } + + Server &set_exception_handler(ExceptionHandler handler); + Server &set_pre_routing_handler(HandlerWithResponse handler); + Server &set_post_routing_handler(Handler handler); + + Server &set_expect_100_continue_handler(Expect100ContinueHandler handler); + Server &set_logger(Logger logger); + + Server &set_address_family(int family); + Server &set_tcp_nodelay(bool on); + Server &set_ipv6_v6only(bool on); + Server &set_socket_options(SocketOptions socket_options); + + Server &set_default_headers(Headers headers); + Server & + set_header_writer(std::function const &writer); + + Server &set_keep_alive_max_count(size_t count); + Server &set_keep_alive_timeout(time_t sec); + + Server &set_read_timeout(time_t sec, time_t usec = 0); + template + Server &set_read_timeout(const std::chrono::duration &duration); + + Server &set_write_timeout(time_t sec, time_t usec = 0); + template + Server &set_write_timeout(const std::chrono::duration &duration); + + Server &set_idle_interval(time_t sec, time_t usec = 0); + template + Server &set_idle_interval(const std::chrono::duration &duration); + + Server &set_payload_max_length(size_t length); + + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); + int bind_to_any_port(const std::string &host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const std::string &host, int port, int socket_flags = 0); + + bool is_running() const; + void wait_until_ready() const; + void stop(); + void decommission(); + + std::function new_task_queue; + +protected: + bool process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, + bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_{INVALID_SOCKET}; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + +private: + using Handlers = + std::vector, Handler>>; + using HandlersForContentReader = + std::vector, + HandlerWithContentReader>>; + + static std::unique_ptr + make_matcher(const std::string &pattern); + + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); + Server &set_error_handler_core(Handler handler, std::false_type); + + socket_t create_server_socket(const std::string &host, int port, + int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const std::string &host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(const Request &req, Response &res, + bool head = false); + bool dispatch_request(Request &req, Response &res, + const Handlers &handlers) const; + bool dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const; + + bool parse_request_line(const char *s, Request &req) const; + void apply_ranges(const Request &req, Response &res, + std::string &content_type, std::string &boundary) const; + bool write_response(Stream &strm, bool close_connection, Request &req, + Response &res); + bool write_response_with_content(Stream &strm, bool close_connection, + const Request &req, Response &res); + bool write_response_core(Stream &strm, bool close_connection, + const Request &req, Response &res, + bool need_apply_ranges); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool + read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) const; + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_{false}; + std::atomic is_decommissioned{false}; + + struct MountPointEntry { + std::string mount_point; + std::string base_dir; + Headers headers; + }; + std::vector base_dirs_; + std::map file_extension_and_mimetype_map_; + std::string default_file_mimetype_ = "application/octet-stream"; + Handler file_request_handler_; + + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + + HandlerWithResponse error_handler_; + ExceptionHandler exception_handler_; + HandlerWithResponse pre_routing_handler_; + Handler post_routing_handler_; + Expect100ContinueHandler expect_100_continue_handler_; + + Logger logger_; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = default_socket_options; + + Headers default_headers_; + std::function header_writer_ = + detail::write_headers; +}; + +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + SSLServerHostnameVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + +class Result { +public: + Result() = default; + Result(std::unique_ptr &&res, Error err, + Headers &&request_headers = Headers{}) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)) {} + // Response + operator bool() const { return res_ != nullptr; } + bool operator==(std::nullptr_t) const { return res_ == nullptr; } + bool operator!=(std::nullptr_t) const { return res_ != nullptr; } + const Response &value() const { return *res_; } + Response &value() { return *res_; } + const Response &operator*() const { return *res_; } + Response &operator*() { return *res_; } + const Response *operator->() const { return res_.get(); } + Response *operator->() { return res_.get(); } + + // Error + Error error() const { return err_; } + + // Request Headers + bool has_request_header(const std::string &key) const; + std::string get_request_header_value(const std::string &key, + const char *def = "", + size_t id = 0) const; + uint64_t get_request_header_value_u64(const std::string &key, + uint64_t def = 0, size_t id = 0) const; + size_t get_request_header_value_count(const std::string &key) const; + +private: + std::unique_ptr res_; + Error err_ = Error::Unknown; + Headers request_headers_; +}; + +class ClientImpl { +public: + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, + Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, size_t content_length, + ContentProvider content_provider, const std::string &content_type); + Result Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void + set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_ipv6_v6only(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template + void + set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template + void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template + void set_write_timeout(const std::chrono::duration &duration); + + void set_max_timeout(time_t msec); + template + void set_max_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, + const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, + const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, + const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + void set_ca_cert_store(X509_STORE *ca_cert_store); + X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier( + std::function verifier); +#endif + + void set_logger(Logger logger); + +protected: + struct Socket { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { return sock != INVALID_SOCKET; } + }; + + virtual bool create_and_connect_socket(Socket &socket, Error &error); + + // All of: + // shutdown_ssl + // shutdown_socket + // close_socket + // should ONLY be called when socket_mutex_ is locked. + // Also, shutdown_ssl and close_socket should also NOT be called concurrently + // with a DIFFERENT thread sending requests using that socket. + virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully); + void shutdown_socket(Socket &socket) const; + void close_socket(Socket &socket); + + bool process_request(Stream &strm, Request &req, Response &res, + bool close_connection, Error &error); + + bool write_content_with_provider(Stream &strm, const Request &req, + Error &error) const; + + void copy_settings(const ClientImpl &rhs); + + // Socket endpoint information + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; + + // These are all protected under socket_mutex + size_t socket_requests_in_flight_ = 0; + std::thread::id socket_requests_are_from_thread_ = std::thread::id(); + bool socket_should_be_closed_when_request_is_done_ = false; + + // Hostname-IP map + std::map addr_map_; + + // Default headers + Headers default_headers_; + + // Header writer + std::function header_writer_ = + detail::write_headers; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; + time_t max_timeout_msec_ = CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND; + + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool keep_alive_ = false; + bool follow_location_ = false; + + bool url_encode_ = true; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = nullptr; + + bool compress_ = false; + bool decompress_ = true; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_ = -1; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + + X509_STORE *ca_cert_store_ = nullptr; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::function server_certificate_verifier_; +#endif + + Logger logger_; + +private: + bool send_(Request &req, Response &res, Error &error); + Result send_(Request &&req); + + socket_t create_client_socket(Error &error) const; + bool read_response_line(Stream &strm, const Request &req, + Response &res) const; + bool write_request(Stream &strm, Request &req, bool close_connection, + Error &error); + bool redirect(Request &req, Response &res, Error &error); + bool handle_request(Stream &strm, Request &req, Response &res, + bool close_connection, Error &error); + std::unique_ptr send_with_content_provider( + Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error); + Result send_with_content_provider( + const std::string &method, const std::string &path, + const Headers &headers, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Progress progress); + ContentProviderWithoutLength get_multipart_content_provider( + const std::string &boundary, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) const; + + std::string adjust_host_string(const std::string &host) const; + + virtual bool + process_socket(const Socket &socket, + std::chrono::time_point start_time, + std::function callback); + virtual bool is_ssl() const; +}; + +class Client { +public: + // Universal interface + explicit Client(const std::string &scheme_host_port); + + explicit Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + Client(Client &&) = default; + Client &operator=(Client &&) = default; + + ~Client(); + + bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, + Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, size_t content_length, + ContentProvider content_provider, const std::string &content_type); + Result Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void + set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template + void + set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template + void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template + void set_write_timeout(const std::chrono::duration &duration); + + void set_max_timeout(time_t msec); + template + void set_max_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, + const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, + const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, + const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier( + std::function verifier); +#endif + + void set_logger(Logger logger); + + // SSL +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; +#endif + +private: + std::unique_ptr cli_; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_ = false; +#endif +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLServer : public Server { +public: + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr, + const char *private_key_password = nullptr); + + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + + SSLServer( + const std::function &setup_ssl_ctx_callback); + + ~SSLServer() override; + + bool is_valid() const override; + + SSL_CTX *ssl_context() const; + + void update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + +private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; +}; + +class SSLClient final : public ClientImpl { +public: + explicit SSLClient(const std::string &host); + + explicit SSLClient(const std::string &host, int port); + + explicit SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password = std::string()); + + explicit SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key, + const std::string &private_key_password = std::string()); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; + +private: + bool create_and_connect_socket(Socket &socket, Error &error) override; + void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; + void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); + + bool + process_socket(const Socket &socket, + std::chrono::time_point start_time, + std::function callback) override; + bool is_ssl() const override; + + bool connect_with_proxy( + Socket &sock, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error); + bool initialize_ssl(Socket &socket, Error &error); + + bool load_certs(); + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; + + std::vector host_components_; + + long verify_result_ = 0; + + friend class ClientImpl; +}; +#endif + +/* + * Implementation of template methods. + */ + +namespace detail { + +template +inline void duration_to_sec_and_usec(const T &duration, U callback) { + auto sec = std::chrono::duration_cast(duration).count(); + auto usec = std::chrono::duration_cast( + duration - std::chrono::seconds(sec)) + .count(); + callback(static_cast(sec), static_cast(usec)); +} + +template inline constexpr size_t str_len(const char (&)[N]) { + return N - 1; +} + +inline bool is_numeric(const std::string &str) { + return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit); +} + +inline uint64_t get_header_value_u64(const Headers &headers, + const std::string &key, uint64_t def, + size_t id, bool &is_invalid_value) { + is_invalid_value = false; + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + if (is_numeric(it->second)) { + return std::strtoull(it->second.data(), nullptr, 10); + } else { + is_invalid_value = true; + } + } + return def; +} + +inline uint64_t get_header_value_u64(const Headers &headers, + const std::string &key, uint64_t def, + size_t id) { + bool dummy = false; + return get_header_value_u64(headers, key, def, id, dummy); +} + +} // namespace detail + +inline uint64_t Request::get_header_value_u64(const std::string &key, + uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +inline uint64_t Response::get_header_value_u64(const std::string &key, + uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +namespace detail { + +inline bool set_socket_opt_impl(socket_t sock, int level, int optname, + const void *optval, socklen_t optlen) { + return setsockopt(sock, level, optname, +#ifdef _WIN32 + reinterpret_cast(optval), +#else + optval, +#endif + optlen) == 0; +} + +inline bool set_socket_opt(socket_t sock, int level, int optname, int optval) { + return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval)); +} + +inline bool set_socket_opt_time(socket_t sock, int level, int optname, + time_t sec, time_t usec) { +#ifdef _WIN32 + auto timeout = static_cast(sec * 1000 + usec / 1000); +#else + timeval timeout; + timeout.tv_sec = static_cast(sec); + timeout.tv_usec = static_cast(usec); +#endif + return set_socket_opt_impl(sock, level, optname, &timeout, sizeof(timeout)); +} + +} // namespace detail + +inline void default_socket_options(socket_t sock) { + detail::set_socket_opt(sock, SOL_SOCKET, +#ifdef SO_REUSEPORT + SO_REUSEPORT, +#else + SO_REUSEADDR, +#endif + 1); +} + +inline const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: return "Continue"; + case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; + case StatusCode::Processing_102: return "Processing"; + case StatusCode::EarlyHints_103: return "Early Hints"; + case StatusCode::OK_200: return "OK"; + case StatusCode::Created_201: return "Created"; + case StatusCode::Accepted_202: return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: return "No Content"; + case StatusCode::ResetContent_205: return "Reset Content"; + case StatusCode::PartialContent_206: return "Partial Content"; + case StatusCode::MultiStatus_207: return "Multi-Status"; + case StatusCode::AlreadyReported_208: return "Already Reported"; + case StatusCode::IMUsed_226: return "IM Used"; + case StatusCode::MultipleChoices_300: return "Multiple Choices"; + case StatusCode::MovedPermanently_301: return "Moved Permanently"; + case StatusCode::Found_302: return "Found"; + case StatusCode::SeeOther_303: return "See Other"; + case StatusCode::NotModified_304: return "Not Modified"; + case StatusCode::UseProxy_305: return "Use Proxy"; + case StatusCode::unused_306: return "unused"; + case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; + case StatusCode::BadRequest_400: return "Bad Request"; + case StatusCode::Unauthorized_401: return "Unauthorized"; + case StatusCode::PaymentRequired_402: return "Payment Required"; + case StatusCode::Forbidden_403: return "Forbidden"; + case StatusCode::NotFound_404: return "Not Found"; + case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: return "Request Timeout"; + case StatusCode::Conflict_409: return "Conflict"; + case StatusCode::Gone_410: return "Gone"; + case StatusCode::LengthRequired_411: return "Length Required"; + case StatusCode::PreconditionFailed_412: return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; + case StatusCode::UriTooLong_414: return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: return "Expectation Failed"; + case StatusCode::ImATeapot_418: return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; + case StatusCode::Locked_423: return "Locked"; + case StatusCode::FailedDependency_424: return "Failed Dependency"; + case StatusCode::TooEarly_425: return "Too Early"; + case StatusCode::UpgradeRequired_426: return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: return "Precondition Required"; + case StatusCode::TooManyRequests_429: return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: return "Not Implemented"; + case StatusCode::BadGateway_502: return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; + case StatusCode::LoopDetected_508: return "Loop Detected"; + case StatusCode::NotExtended_510: return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: return "Internal Server Error"; + } +} + +inline std::string get_bearer_token_auth(const Request &req) { + if (req.has_header("Authorization")) { + constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); + return req.get_header_value("Authorization") + .substr(bearer_header_prefix_len); + } + return ""; +} + +template +inline Server & +Server::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); + return *this; +} + +template +inline Server & +Server::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); + return *this; +} + +template +inline Server & +Server::set_idle_interval(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_idle_interval(sec, usec); }); + return *this; +} + +inline std::string to_string(const Error error) { + switch (error) { + case Error::Success: return "Success (no error)"; + case Error::Connection: return "Could not establish connection"; + case Error::BindIPAddress: return "Failed to bind IP address"; + case Error::Read: return "Failed to read connection"; + case Error::Write: return "Failed to write connection"; + case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; + case Error::Canceled: return "Connection handling canceled"; + case Error::SSLConnection: return "SSL connection failed"; + case Error::SSLLoadingCerts: return "SSL certificate loading failed"; + case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: return "Compression failed"; + case Error::ConnectionTimeout: return "Connection timed out"; + case Error::ProxyConnection: return "Proxy connection failed"; + case Error::Unknown: return "Unknown"; + default: break; + } + + return "Invalid"; +} + +inline std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + +inline uint64_t Result::get_request_header_value_u64(const std::string &key, + uint64_t def, + size_t id) const { + return detail::get_header_value_u64(request_headers_, key, def, id); +} + +template +inline void ClientImpl::set_connection_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { + set_connection_timeout(sec, usec); + }); +} + +template +inline void ClientImpl::set_read_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_write_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_max_timeout( + const std::chrono::duration &duration) { + auto msec = + std::chrono::duration_cast(duration).count(); + set_max_timeout(msec); +} + +template +inline void Client::set_connection_timeout( + const std::chrono::duration &duration) { + cli_->set_connection_timeout(duration); +} + +template +inline void +Client::set_read_timeout(const std::chrono::duration &duration) { + cli_->set_read_timeout(duration); +} + +template +inline void +Client::set_write_timeout(const std::chrono::duration &duration) { + cli_->set_write_timeout(duration); +} + +template +inline void +Client::set_max_timeout(const std::chrono::duration &duration) { + cli_->set_max_timeout(duration); +} + +/* + * Forward declarations and types that will be part of the .h file if split into + * .h + .cc. + */ + +std::string hosted_at(const std::string &hostname); + +void hosted_at(const std::string &hostname, std::vector &addrs); + +std::string append_query_params(const std::string &path, const Params ¶ms); + +std::pair make_range_header(const Ranges &ranges); + +std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false); + +namespace detail { + +#if defined(_WIN32) +inline std::wstring u8string_to_wstring(const char *s) { + std::wstring ws; + auto len = static_cast(strlen(s)); + auto wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, nullptr, 0); + if (wlen > 0) { + ws.resize(wlen); + wlen = ::MultiByteToWideChar( + CP_UTF8, 0, s, len, + const_cast(reinterpret_cast(ws.data())), wlen); + if (wlen != static_cast(ws.size())) { ws.clear(); } + } + return ws; +} +#endif + +struct FileStat { + FileStat(const std::string &path); + bool is_file() const; + bool is_dir() const; + +private: +#if defined(_WIN32) + struct _stat st_; +#else + struct stat st_; +#endif + int ret_ = -1; +}; + +std::string encode_query_param(const std::string &value); + +std::string decode_url(const std::string &s, bool convert_plus_to_space); + +std::string trim_copy(const std::string &s); + +void divide( + const char *data, std::size_t size, char d, + std::function + fn); + +void divide( + const std::string &str, char d, + std::function + fn); + +void split(const char *b, const char *e, char d, + std::function fn); + +void split(const char *b, const char *e, char d, size_t m, + std::function fn); + +bool process_client_socket( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, + std::function callback); + +socket_t create_client_socket(const std::string &host, const std::string &ip, + int port, int address_family, bool tcp_nodelay, + bool ipv6_v6only, SocketOptions socket_options, + time_t connection_timeout_sec, + time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + const std::string &intf, Error &error); + +const char *get_header_value(const Headers &headers, const std::string &key, + const char *def, size_t id); + +std::string params_to_query_str(const Params ¶ms); + +void parse_query_text(const char *data, std::size_t size, Params ¶ms); + +void parse_query_text(const std::string &s, Params ¶ms); + +bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary); + +bool parse_range_header(const std::string &s, Ranges &ranges); + +int close_socket(socket_t sock); + +ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); + +ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); + +enum class EncodingType { None = 0, Gzip, Brotli, Zstd }; + +EncodingType encoding_type(const Request &req, const Response &res); + +class BufferStream final : public Stream { +public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + + const std::string &get_buffer() const; + +private: + std::string buffer; + size_t position = 0; +}; + +class compressor { +public: + virtual ~compressor() = default; + + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, + Callback callback) = 0; +}; + +class decompressor { +public: + virtual ~decompressor() = default; + + virtual bool is_valid() const = 0; + + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, + Callback callback) = 0; +}; + +class nocompressor final : public compressor { +public: + ~nocompressor() override = default; + + bool compress(const char *data, size_t data_length, bool /*last*/, + Callback callback) override; +}; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +class gzip_compressor final : public compressor { +public: + gzip_compressor(); + ~gzip_compressor() override; + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; + +class gzip_decompressor final : public decompressor { +public: + gzip_decompressor(); + ~gzip_decompressor() override; + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +class brotli_compressor final : public compressor { +public: + brotli_compressor(); + ~brotli_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + BrotliEncoderState *state_ = nullptr; +}; + +class brotli_decompressor final : public decompressor { +public: + brotli_decompressor(); + ~brotli_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; +}; +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +class zstd_compressor : public compressor { +public: + zstd_compressor(); + ~zstd_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + ZSTD_CCtx *ctx_ = nullptr; +}; + +class zstd_decompressor : public decompressor { +public: + zstd_decompressor(); + ~zstd_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + ZSTD_DCtx *ctx_ = nullptr; +}; +#endif + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. +class stream_line_reader { +public: + stream_line_reader(Stream &strm, char *fixed_buffer, + size_t fixed_buffer_size); + const char *ptr() const; + size_t size() const; + bool end_with_crlf() const; + bool getline(); + +private: + void append(char c); + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string growable_buffer_; +}; + +class mmap { +public: + mmap(const char *path); + ~mmap(); + + bool open(const char *path); + void close(); + + bool is_open() const; + size_t size() const; + const char *data() const; + +private: +#if defined(_WIN32) + HANDLE hFile_ = NULL; + HANDLE hMapping_ = NULL; +#else + int fd_ = -1; +#endif + size_t size_ = 0; + void *addr_ = nullptr; + bool is_open_empty_file = false; +}; + +// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 +namespace fields { + +inline bool is_token_char(char c) { + return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; +} + +inline bool is_token(const std::string &s) { + if (s.empty()) { return false; } + for (auto c : s) { + if (!is_token_char(c)) { return false; } + } + return true; +} + +inline bool is_field_name(const std::string &s) { return is_token(s); } + +inline bool is_vchar(char c) { return c >= 33 && c <= 126; } + +inline bool is_obs_text(char c) { return 128 <= static_cast(c); } + +inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } + +inline bool is_field_content(const std::string &s) { + if (s.empty()) { return true; } + + if (s.size() == 1) { + return is_field_vchar(s[0]); + } else if (s.size() == 2) { + return is_field_vchar(s[0]) && is_field_vchar(s[1]); + } else { + size_t i = 0; + + if (!is_field_vchar(s[i])) { return false; } + i++; + + while (i < s.size() - 1) { + auto c = s[i++]; + if (c == ' ' || c == '\t' || is_field_vchar(c)) { + } else { + return false; + } + } + + return is_field_vchar(s[i]); + } +} + +inline bool is_field_value(const std::string &s) { return is_field_content(s); } + +} // namespace fields + +} // namespace detail + +// ---------------------------------------------------------------------------- + +/* + * Implementation that will be part of the .cc file if split into .h + .cc. + */ + +namespace detail { + +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + auto v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + static const auto charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = static_cast(code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + auto val = 0; + auto valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + if (path[i] == '\0') { + return false; + } else if (path[i] == '\\') { + return false; + } + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline FileStat::FileStat(const std::string &path) { +#if defined(_WIN32) + auto wpath = u8string_to_wstring(path.c_str()); + ret_ = _wstat(wpath.c_str(), &st_); +#else + ret_ = stat(path.c_str(), &st_); +#endif +} +inline bool FileStat::is_file() const { + return ret_ >= 0 && S_ISREG(st_.st_mode); +} +inline bool FileStat::is_dir() const { + return ret_ >= 0 && S_ISDIR(st_.st_mode); +} + +inline std::string encode_query_param(const std::string &value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (auto c : value) { + if (std::isalnum(static_cast(c)) || c == '-' || c == '_' || + c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || + c == ')') { + escaped << c; + } else { + escaped << std::uppercase; + escaped << '%' << std::setw(2) + << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} + +inline std::string encode_url(const std::string &s) { + std::string result; + result.reserve(s.size()); + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s, + bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + auto val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + auto val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + thread_local auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); +} + +inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } + +inline std::pair trim(const char *b, const char *e, size_t left, + size_t right) { + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); +} + +inline std::string trim_copy(const std::string &s) { + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); +} + +inline std::string trim_double_quotes_copy(const std::string &s) { + if (s.length() >= 2 && s.front() == '"' && s.back() == '"') { + return s.substr(1, s.size() - 2); + } + return s; +} + +inline void +divide(const char *data, std::size_t size, char d, + std::function + fn) { + const auto it = std::find(data, data + size, d); + const auto found = static_cast(it != data + size); + const auto lhs_data = data; + const auto lhs_size = static_cast(it - data); + const auto rhs_data = it + found; + const auto rhs_size = size - lhs_size - found; + + fn(lhs_data, lhs_size, rhs_data, rhs_size); +} + +inline void +divide(const std::string &str, char d, + std::function + fn) { + divide(str.data(), str.size(), d, std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, + std::function fn) { + return split(b, e, d, (std::numeric_limits::max)(), std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, size_t m, + std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } + } +} + +inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, + size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} + +inline const char *stream_line_reader::ptr() const { + if (growable_buffer_.empty()) { + return fixed_buffer_; + } else { + return growable_buffer_.data(); + } +} + +inline size_t stream_line_reader::size() const { + if (growable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return growable_buffer_.size(); + } +} + +inline bool stream_line_reader::end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; +} + +inline bool stream_line_reader::getline() { + fixed_buffer_used_size_ = 0; + growable_buffer_.clear(); + +#ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + char prev_byte = 0; +#endif + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + if (byte == '\n') { break; } +#else + if (prev_byte == '\r' && byte == '\n') { break; } + prev_byte = byte; +#endif + } + + return true; +} + +inline void stream_line_reader::append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (growable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + growable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + growable_buffer_ += c; + } +} + +inline mmap::mmap(const char *path) { open(path); } + +inline mmap::~mmap() { close(); } + +inline bool mmap::open(const char *path) { + close(); + +#if defined(_WIN32) + auto wpath = u8string_to_wstring(path); + if (wpath.empty()) { return false; } + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, + OPEN_EXISTING, NULL); +#else + hFile_ = ::CreateFileW(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, + OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); +#endif + + if (hFile_ == INVALID_HANDLE_VALUE) { return false; } + + LARGE_INTEGER size{}; + if (!::GetFileSizeEx(hFile_, &size)) { return false; } + // If the following line doesn't compile due to QuadPart, update Windows SDK. + // See: + // https://github.com/yhirose/cpp-httplib/issues/1903#issuecomment-2316520721 + if (static_cast(size.QuadPart) > + (std::numeric_limits::max)()) { + // `size_t` might be 32-bits, on 32-bits Windows. + return false; + } + size_ = static_cast(size.QuadPart); + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + hMapping_ = + ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL); +#else + hMapping_ = ::CreateFileMappingW(hFile_, NULL, PAGE_READONLY, 0, 0, NULL); +#endif + + // Special treatment for an empty file... + if (hMapping_ == NULL && size_ == 0) { + close(); + is_open_empty_file = true; + return true; + } + + if (hMapping_ == NULL) { + close(); + return false; + } + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + addr_ = ::MapViewOfFileFromApp(hMapping_, FILE_MAP_READ, 0, 0); +#else + addr_ = ::MapViewOfFile(hMapping_, FILE_MAP_READ, 0, 0, 0); +#endif + + if (addr_ == nullptr) { + close(); + return false; + } +#else + fd_ = ::open(path, O_RDONLY); + if (fd_ == -1) { return false; } + + struct stat sb; + if (fstat(fd_, &sb) == -1) { + close(); + return false; + } + size_ = static_cast(sb.st_size); + + addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0); + + // Special treatment for an empty file... + if (addr_ == MAP_FAILED && size_ == 0) { + close(); + is_open_empty_file = true; + return false; + } +#endif + + return true; +} + +inline bool mmap::is_open() const { + return is_open_empty_file ? true : addr_ != nullptr; +} + +inline size_t mmap::size() const { return size_; } + +inline const char *mmap::data() const { + return is_open_empty_file ? "" : static_cast(addr_); +} + +inline void mmap::close() { +#if defined(_WIN32) + if (addr_) { + ::UnmapViewOfFile(addr_); + addr_ = nullptr; + } + + if (hMapping_) { + ::CloseHandle(hMapping_); + hMapping_ = NULL; + } + + if (hFile_ != INVALID_HANDLE_VALUE) { + ::CloseHandle(hFile_); + hFile_ = INVALID_HANDLE_VALUE; + } + + is_open_empty_file = false; +#else + if (addr_ != nullptr) { + munmap(addr_, size_); + addr_ = nullptr; + } + + if (fd_ != -1) { + ::close(fd_); + fd_ = -1; + } +#endif + size_ = 0; +} +inline int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +template inline ssize_t handle_EINTR(T fn) { + ssize_t res = 0; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } + break; + } + return res; +} + +inline ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return recv(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, + int flags) { + return handle_EINTR([&]() { + return send(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline int poll_wrapper(struct pollfd *fds, nfds_t nfds, int timeout) { +#ifdef _WIN32 + return ::WSAPoll(fds, nfds, timeout); +#else + return ::poll(fds, nfds, timeout); +#endif +} + +template +inline ssize_t select_impl(socket_t sock, time_t sec, time_t usec) { + struct pollfd pfd; + pfd.fd = sock; + pfd.events = (Read ? POLLIN : POLLOUT); + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll_wrapper(&pfd, 1, timeout); }); +} + +inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { + return select_impl(sock, sec, usec); +} + +inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { + return select_impl(sock, sec, usec); +} + +inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, + time_t usec) { + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = + handle_EINTR([&]() { return poll_wrapper(&pfd_read, 1, timeout); }); + + if (poll_res == 0) { return Error::ConnectionTimeout; } + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + + return Error::Connection; +} + +inline bool is_socket_alive(socket_t sock) { + const auto val = detail::select_read(sock, 0, 0); + if (val == 0) { + return true; + } else if (val < 0 && errno == EBADF) { + return false; + } + char buf[1]; + return detail::read_socket(sock, &buf[0], sizeof(buf), MSG_PEEK) > 0; +} + +class SocketStream final : public Stream { +public: + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + (std::chrono::steady_clock::time_point::min)()); + ~SocketStream() override; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + +private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time_; + + std::vector read_buff_; + size_t read_buff_off_ = 0; + size_t read_buff_content_size_ = 0; + + static const size_t read_buff_size_ = 1024l * 4; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream final : public Stream { +public: + SSLSocketStream( + socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + (std::chrono::steady_clock::time_point::min)()); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + +private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time_; +}; +#endif + +inline bool keep_alive(const std::atomic &svr_sock, socket_t sock, + time_t keep_alive_timeout_sec) { + using namespace std::chrono; + + const auto interval_usec = + CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND; + + // Avoid expensive `steady_clock::now()` call for the first time + if (select_read(sock, 0, interval_usec) > 0) { return true; } + + const auto start = steady_clock::now() - microseconds{interval_usec}; + const auto timeout = seconds{keep_alive_timeout_sec}; + + while (true) { + if (svr_sock == INVALID_SOCKET) { + break; // Server socket is closed + } + + auto val = select_read(sock, 0, interval_usec); + if (val < 0) { + break; // Ssocket error + } else if (val == 0) { + if (steady_clock::now() - start > timeout) { + break; // Timeout + } + } else { + return true; // Ready for read + } + } + + return false; +} + +template +inline bool +process_server_socket_core(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, T callback) { + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (count > 0 && keep_alive(svr_sock, sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { break; } + count--; + } + return ret; +} + +template +inline bool +process_server_socket(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +inline bool process_client_socket( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, + std::function callback) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); + return callback(strm); +} + +inline int shutdown_socket(socket_t sock) { +#ifdef _WIN32 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif +} + +inline std::string escape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '\0') { + auto ret = s; + ret[0] = '@'; + return ret; + } + return s; +} + +inline std::string +unescape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '@') { + auto ret = s; + ret[0] = '\0'; + return ret; + } + return s; +} + +template +socket_t create_socket(const std::string &host, const std::string &ip, int port, + int address_family, int socket_flags, bool tcp_nodelay, + bool ipv6_v6only, SocketOptions socket_options, + BindOrConnect bind_or_connect) { + // Get address info + const char *node = nullptr; + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_IP; + + if (!ip.empty()) { + node = ip.c_str(); + // Ask getaddrinfo to convert IP in c-string to address + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + } else { + if (!host.empty()) { node = host.c_str(); } + hints.ai_family = address_family; + hints.ai_flags = socket_flags; + } + + if (hints.ai_family == AF_UNIX) { + const auto addrlen = host.length(); + if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; } + +#ifdef SOCK_CLOEXEC + auto sock = socket(hints.ai_family, hints.ai_socktype | SOCK_CLOEXEC, + hints.ai_protocol); +#else + auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); +#endif + + if (sock != INVALID_SOCKET) { + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + + auto unescaped_host = unescape_abstract_namespace_unix_domain(host); + std::copy(unescaped_host.begin(), unescaped_host.end(), addr.sun_path); + + hints.ai_addr = reinterpret_cast(&addr); + hints.ai_addrlen = static_cast( + sizeof(addr) - sizeof(addr.sun_path) + addrlen); + +#ifndef SOCK_CLOEXEC +#ifndef _WIN32 + fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif +#endif + + if (socket_options) { socket_options(sock); } + +#ifdef _WIN32 + // Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so + // remove the option. + detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0); +#endif + + bool dummy; + if (!bind_or_connect(sock, hints, dummy)) { + close_socket(sock); + sock = INVALID_SOCKET; + } + } + return sock; + } + + auto service = std::to_string(port); + + if (getaddrinfo(node, service.c_str(), &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return INVALID_SOCKET; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = + WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0, + WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + +#ifdef SOCK_CLOEXEC + auto sock = + socket(rp->ai_family, rp->ai_socktype | SOCK_CLOEXEC, rp->ai_protocol); +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + +#endif + if (sock == INVALID_SOCKET) { continue; } + +#if !defined _WIN32 && !defined SOCK_CLOEXEC + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + close_socket(sock); + continue; + } +#endif + + if (tcp_nodelay) { set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); } + + if (rp->ai_family == AF_INET6) { + set_socket_opt(sock, IPPROTO_IPV6, IPV6_V6ONLY, ipv6_v6only ? 1 : 0); + } + + if (socket_options) { socket_options(sock); } + + // bind or connect + auto quit = false; + if (bind_or_connect(sock, *rp, quit)) { return sock; } + + close_socket(sock); + + if (quit) { break; } + } + + return INVALID_SOCKET; +} + +inline void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN32 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif +} + +inline bool is_connection_error() { +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif +} + +inline bool bind_ip_address(socket_t sock, const std::string &host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host.c_str(), "0", &hints, &result)) { return false; } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + return ret; +} + +#if !defined _WIN32 && !defined ANDROID && !defined _AIX && !defined __MVS__ +#define USE_IF2IP +#endif + +#ifdef USE_IF2IP +inline std::string if2ip(int address_family, const std::string &ifn) { + struct ifaddrs *ifap; + getifaddrs(&ifap); + auto se = detail::scope_exit([&] { freeifaddrs(ifap); }); + + std::string addr_candidate; + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name && + (AF_UNSPEC == address_family || + ifa->ifa_addr->sa_family == address_family)) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + return std::string(buf, INET_ADDRSTRLEN); + } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + auto sa = reinterpret_cast(ifa->ifa_addr); + if (!IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) { + char buf[INET6_ADDRSTRLEN] = {}; + if (inet_ntop(AF_INET6, &sa->sin6_addr, buf, INET6_ADDRSTRLEN)) { + // equivalent to mac's IN6_IS_ADDR_UNIQUE_LOCAL + auto s6_addr_head = sa->sin6_addr.s6_addr[0]; + if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { + addr_candidate = std::string(buf, INET6_ADDRSTRLEN); + } else { + return std::string(buf, INET6_ADDRSTRLEN); + } + } + } + } + } + } + return addr_candidate; +} +#endif + +inline socket_t create_client_socket( + const std::string &host, const std::string &ip, int port, + int address_family, bool tcp_nodelay, bool ipv6_v6only, + SocketOptions socket_options, time_t connection_timeout_sec, + time_t connection_timeout_usec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, const std::string &intf, Error &error) { + auto sock = create_socket( + host, ip, port, address_family, 0, tcp_nodelay, ipv6_v6only, + std::move(socket_options), + [&](socket_t sock2, struct addrinfo &ai, bool &quit) -> bool { + if (!intf.empty()) { +#ifdef USE_IF2IP + auto ip_from_if = if2ip(address_family, intf); + if (ip_from_if.empty()) { ip_from_if = intf; } + if (!bind_ip_address(sock2, ip_from_if)) { + error = Error::BindIPAddress; + return false; + } +#endif + } + + set_nonblocking(sock2, true); + + auto ret = + ::connect(sock2, ai.ai_addr, static_cast(ai.ai_addrlen)); + + if (ret < 0) { + if (is_connection_error()) { + error = Error::Connection; + return false; + } + error = wait_until_socket_is_ready(sock2, connection_timeout_sec, + connection_timeout_usec); + if (error != Error::Success) { + if (error == Error::ConnectionTimeout) { quit = true; } + return false; + } + } + + set_nonblocking(sock2, false); + set_socket_opt_time(sock2, SOL_SOCKET, SO_RCVTIMEO, read_timeout_sec, + read_timeout_usec); + set_socket_opt_time(sock2, SOL_SOCKET, SO_SNDTIMEO, write_timeout_sec, + write_timeout_usec); + + error = Error::Success; + return true; + }); + + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { error = Error::Connection; } + } + + return sock; +} + +inline bool get_ip_and_port(const struct sockaddr_storage &addr, + socklen_t addr_len, std::string &ip, int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return false; + } + + std::array ipstr{}; + if (getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + return false; + } + + ip = ipstr.data(); + return true; +} + +inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (!getsockname(sock, reinterpret_cast(&addr), + &addr_len)) { + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { +#ifndef _WIN32 + if (addr.ss_family == AF_UNIX) { +#if defined(__linux__) + struct ucred ucred; + socklen_t len = sizeof(ucred); + if (getsockopt(sock, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == 0) { + port = ucred.pid; + } +#elif defined(SOL_LOCAL) && defined(SO_PEERPID) // __APPLE__ + pid_t pid; + socklen_t len = sizeof(pid); + if (getsockopt(sock, SOL_LOCAL, SO_PEERPID, &pid, &len) == 0) { + port = pid; + } +#endif + return; + } +#endif + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline constexpr unsigned int str2tag_core(const char *s, size_t l, + unsigned int h) { + return (l == 0) + ? h + : str2tag_core( + s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(*s)); +} + +inline unsigned int str2tag(const std::string &s) { + return str2tag_core(s.data(), s.size(), 0); +} + +namespace udl { + +inline constexpr unsigned int operator""_t(const char *s, size_t l) { + return str2tag_core(s, l, 0); +} + +} // namespace udl + +inline std::string +find_content_type(const std::string &path, + const std::map &user_data, + const std::string &default_content_type) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { return it->second; } + + using udl::operator""_t; + + switch (str2tag(ext)) { + default: return default_content_type; + + case "css"_t: return "text/css"; + case "csv"_t: return "text/csv"; + case "htm"_t: + case "html"_t: return "text/html"; + case "js"_t: + case "mjs"_t: return "text/javascript"; + case "txt"_t: return "text/plain"; + case "vtt"_t: return "text/vtt"; + + case "apng"_t: return "image/apng"; + case "avif"_t: return "image/avif"; + case "bmp"_t: return "image/bmp"; + case "gif"_t: return "image/gif"; + case "png"_t: return "image/png"; + case "svg"_t: return "image/svg+xml"; + case "webp"_t: return "image/webp"; + case "ico"_t: return "image/x-icon"; + case "tif"_t: return "image/tiff"; + case "tiff"_t: return "image/tiff"; + case "jpg"_t: + case "jpeg"_t: return "image/jpeg"; + + case "mp4"_t: return "video/mp4"; + case "mpeg"_t: return "video/mpeg"; + case "webm"_t: return "video/webm"; + + case "mp3"_t: return "audio/mp3"; + case "mpga"_t: return "audio/mpeg"; + case "weba"_t: return "audio/webm"; + case "wav"_t: return "audio/wave"; + + case "otf"_t: return "font/otf"; + case "ttf"_t: return "font/ttf"; + case "woff"_t: return "font/woff"; + case "woff2"_t: return "font/woff2"; + + case "7z"_t: return "application/x-7z-compressed"; + case "atom"_t: return "application/atom+xml"; + case "pdf"_t: return "application/pdf"; + case "json"_t: return "application/json"; + case "rss"_t: return "application/rss+xml"; + case "tar"_t: return "application/x-tar"; + case "xht"_t: + case "xhtml"_t: return "application/xhtml+xml"; + case "xslt"_t: return "application/xslt+xml"; + case "xml"_t: return "application/xml"; + case "gz"_t: return "application/gzip"; + case "zip"_t: return "application/zip"; + case "wasm"_t: return "application/wasm"; + } +} + +inline bool can_compress_content_type(const std::string &content_type) { + using udl::operator""_t; + + auto tag = str2tag(content_type); + + switch (tag) { + case "image/svg+xml"_t: + case "application/javascript"_t: + case "application/json"_t: + case "application/xml"_t: + case "application/protobuf"_t: + case "application/xhtml+xml"_t: return true; + + case "text/event-stream"_t: return false; + + default: return !content_type.rfind("text/", 0); + } +} + +inline EncodingType encoding_type(const Request &req, const Response &res) { + auto ret = + detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { return EncodingType::None; } + + const auto &s = req.get_header_value("Accept-Encoding"); + (void)(s); + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { return EncodingType::Brotli; } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { return EncodingType::Gzip; } +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + // TODO: 'Accept-Encoding' has zstd, not zstd;q=0 + ret = s.find("zstd") != std::string::npos; + if (ret) { return EncodingType::Zstd; } +#endif + + return EncodingType::None; +} + +inline bool nocompressor::compress(const char *data, size_t data_length, + bool /*last*/, Callback callback) { + if (!data_length) { return true; } + return callback(data, data_length); +} + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline gzip_compressor::gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY) == Z_OK; +} + +inline gzip_compressor::~gzip_compressor() { deflateEnd(&strm_); } + +inline bool gzip_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + assert(is_valid_); + + do { + constexpr size_t max_avail_in = + (std::numeric_limits::max)(); + + strm_.avail_in = static_cast( + (std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + auto flush = (last && data_length == 0) ? Z_FINISH : Z_NO_FLUSH; + auto ret = Z_OK; + + std::array buff{}; + do { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = deflate(&strm_, flush); + if (ret == Z_STREAM_ERROR) { return false; } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); + + assert((flush == Z_FINISH && ret == Z_STREAM_END) || + (flush == Z_NO_FLUSH && ret == Z_OK)); + assert(strm_.avail_in == 0); + } while (data_length > 0); + + return true; +} + +inline gzip_decompressor::gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; +} + +inline gzip_decompressor::~gzip_decompressor() { inflateEnd(&strm_); } + +inline bool gzip_decompressor::is_valid() const { return is_valid_; } + +inline bool gzip_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + assert(is_valid_); + + auto ret = Z_OK; + + do { + constexpr size_t max_avail_in = + (std::numeric_limits::max)(); + + strm_.avail_in = static_cast( + (std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + std::array buff{}; + while (strm_.avail_in > 0 && ret == Z_OK) { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm_, Z_NO_FLUSH); + + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm_); return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } + + if (ret != Z_OK && ret != Z_STREAM_END) { return false; } + + } while (data_length > 0); + + return true; +} +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +inline brotli_compressor::brotli_compressor() { + state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); +} + +inline brotli_compressor::~brotli_compressor() { + BrotliEncoderDestroyInstance(state_); +} + +inline bool brotli_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { break; } + } else { + if (!available_in) { break; } + } + + auto available_out = buff.size(); + auto next_out = buff.data(); + + if (!BrotliEncoderCompressStream(state_, operation, &available_in, &next_in, + &available_out, &next_out, nullptr)) { + return false; + } + + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } + + return true; +} + +inline brotli_decompressor::brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT + : BROTLI_DECODER_RESULT_ERROR; +} + +inline brotli_decompressor::~brotli_decompressor() { + if (decoder_s) { BrotliDecoderDestroyInstance(decoder_s); } +} + +inline bool brotli_decompressor::is_valid() const { return decoder_s; } + +inline bool brotli_decompressor::decompress(const char *data, + size_t data_length, + Callback callback) { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } + + auto next_in = reinterpret_cast(data); + size_t avail_in = data_length; + size_t total_out; + + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + + std::array buff{}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); + + decoder_r = BrotliDecoderDecompressStream( + decoder_s, &avail_in, &next_in, &avail_out, + reinterpret_cast(&next_out), &total_out); + + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { return false; } + + if (!callback(buff.data(), buff.size() - avail_out)) { return false; } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; +} +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +inline zstd_compressor::zstd_compressor() { + ctx_ = ZSTD_createCCtx(); + ZSTD_CCtx_setParameter(ctx_, ZSTD_c_compressionLevel, ZSTD_fast); +} + +inline zstd_compressor::~zstd_compressor() { ZSTD_freeCCtx(ctx_); } + +inline bool zstd_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + ZSTD_EndDirective mode = last ? ZSTD_e_end : ZSTD_e_continue; + ZSTD_inBuffer input = {data, data_length, 0}; + + bool finished; + do { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_compressStream2(ctx_, &output, &input, mode); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + + finished = last ? (remaining == 0) : (input.pos == input.size); + + } while (!finished); + + return true; +} + +inline zstd_decompressor::zstd_decompressor() { ctx_ = ZSTD_createDCtx(); } + +inline zstd_decompressor::~zstd_decompressor() { ZSTD_freeDCtx(ctx_); } + +inline bool zstd_decompressor::is_valid() const { return ctx_ != nullptr; } + +inline bool zstd_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + std::array buff{}; + ZSTD_inBuffer input = {data, data_length, 0}; + + while (input.pos < input.size) { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_decompressStream(ctx_, &output, &input); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + } + + return true; +} +#endif + +inline bool has_header(const Headers &headers, const std::string &key) { + return headers.find(key) != headers.end(); +} + +inline const char *get_header_value(const Headers &headers, + const std::string &key, const char *def, + size_t id) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second.c_str(); } + return def; +} + +template +inline bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + auto name = std::string(beg, p); + if (!detail::fields::is_field_name(name)) { return false; } + + if (p == end) { return false; } + + auto key_end = p; + + if (*p++ != ':') { return false; } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p <= end) { + auto key_len = key_end - beg; + if (!key_len) { return false; } + + auto key = std::string(beg, key_end); + auto val = std::string(p, end); + + if (!detail::fields::is_field_value(val)) { return false; } + + if (case_ignore::equal(key, "Location") || + case_ignore::equal(key, "Referer")) { + fn(key, val); + } else { + fn(key, decode_url(val, false)); + } + + return true; + } + + return false; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + auto line_terminator_len = 2; + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } + } else { +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + // Blank line indicates end of headers. + if (line_reader.size() == 1) { break; } + line_terminator_len = 1; +#else + continue; // Skip invalid line. +#endif + } + + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + + // Exclude line terminator + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + if (!parse_header(line_reader.ptr(), end, + [&](const std::string &key, const std::string &val) { + headers.emplace(key, val); + })) { + return false; + } + } + + return true; +} + +inline bool read_content_with_length(Stream &strm, uint64_t len, + Progress progress, + ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } + + if (!out(buf, static_cast(n), r, len)) { return false; } + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { return false; } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += static_cast(n); + } +} + +inline bool read_content_without_length(Stream &strm, + ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n == 0) { return true; } + if (n < 0) { return false; } + + if (!out(buf, static_cast(n), r, 0)) { return false; } + r += static_cast(n); + } + + return true; +} + +template +inline bool read_content_chunked(Stream &strm, T &x, + ContentReceiverWithProgress out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { return false; } + + unsigned long chunk_len; + while (true) { + char *end_ptr; + + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + + if (end_ptr == line_reader.ptr()) { return false; } + if (chunk_len == ULONG_MAX) { return false; } + + if (chunk_len == 0) { break; } + + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { return false; } + + if (strcmp(line_reader.ptr(), "\r\n") != 0) { return false; } + + if (!line_reader.getline()) { return false; } + } + + assert(chunk_len == 0); + + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // does't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-httplib now allows + // chunked transfer coding data without the final CRLF. + if (!line_reader.getline()) { return true; } + + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + + // Exclude line terminator + constexpr auto line_terminator_len = 2; + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + parse_header(line_reader.ptr(), end, + [&](const std::string &key, const std::string &val) { + x.headers.emplace(key, val); + }); + + if (!line_reader.getline()) { return false; } + } + + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return case_ignore::equal( + get_header_value(headers, "Transfer-Encoding", "", 0), "chunked"); +} + +template +bool prepare_content_receiver(T &x, int &status, + ContentReceiverWithProgress receiver, + bool decompress, U callback) { + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } else if (encoding == "zstd") { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } + + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiverWithProgress out = [&](const char *buf, size_t n, + uint64_t off, uint64_t len) { + return decompressor->decompress(buf, n, + [&](const char *buf2, size_t n2) { + return receiver(buf2, n2, off, len); + }); + }; + return callback(std::move(out)); + } else { + status = StatusCode::InternalServerError_500; + return false; + } + } + } + + ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off, + uint64_t len) { + return receiver(buf, n, off, len); + }; + return callback(std::move(out)); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + Progress progress, ContentReceiverWithProgress receiver, + bool decompress) { + return prepare_content_receiver( + x, status, std::move(receiver), decompress, + [&](const ContentReceiverWithProgress &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, x, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto is_invalid_value = false; + auto len = get_header_value_u64( + x.headers, "Content-Length", + (std::numeric_limits::max)(), 0, is_invalid_value); + + if (is_invalid_value) { + ret = false; + } else if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, std::move(progress), out); + } + } + + if (!ret) { + status = exceed_payload_max_length ? StatusCode::PayloadTooLarge_413 + : StatusCode::BadRequest_400; + } + return ret; + }); +} + +inline ssize_t write_request_line(Stream &strm, const std::string &method, + const std::string &path) { + std::string s = method; + s += " "; + s += path; + s += " HTTP/1.1\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_response_line(Stream &strm, int status) { + std::string s = "HTTP/1.1 "; + s += std::to_string(status); + s += " "; + s += httplib::status_message(status); + s += "\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_headers(Stream &strm, const Headers &headers) { + ssize_t write_len = 0; + for (const auto &x : headers) { + std::string s; + s = x.first; + s += ": "; + s += x.second; + s += "\r\n"; + + auto len = strm.write(s.data(), s.size()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; +} + +inline bool write_data(Stream &strm, const char *d, size_t l) { + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { return false; } + offset += static_cast(length); + } + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, + size_t offset, size_t length, T is_shutting_down, + Error &error) { + size_t end_offset = offset + length; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + if (write_data(strm, d, l)) { + offset += l; + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + + while (offset < end_offset && !is_shutting_down()) { + if (!strm.wait_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, end_offset - offset, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, + size_t offset, size_t length, + const T &is_shutting_down) { + auto error = Error::Success; + return write_content(strm, content_provider, offset, length, is_shutting_down, + error); +} + +template +inline bool +write_content_without_length(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + offset += l; + if (!write_data(strm, d, l)) { ok = false; } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + + data_sink.done = [&](void) { data_available = false; }; + + while (data_available && !is_shutting_down()) { + if (!strm.wait_writable()) { + return false; + } else if (!content_provider(offset, 0, data_sink)) { + return false; + } else if (!ok) { + return false; + } + } + return true; +} + +template +inline bool +write_content_chunked(Stream &strm, const ContentProvider &content_provider, + const T &is_shutting_down, U &compressor, Error &error) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + data_available = l > 0; + offset += l; + + std::string payload; + if (compressor.compress(d, l, false, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = + from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; } + } + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + + auto done_with_trailer = [&](const Headers *trailer) { + if (!ok) { return; } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!write_data(strm, chunk.data(), chunk.size())) { + ok = false; + return; + } + } + + constexpr const char done_marker[] = "0\r\n"; + if (!write_data(strm, done_marker, str_len(done_marker))) { ok = false; } + + // Trailer + if (trailer) { + for (const auto &kv : *trailer) { + std::string field_line = kv.first + ": " + kv.second + "\r\n"; + if (!write_data(strm, field_line.data(), field_line.size())) { + ok = false; + } + } + } + + constexpr const char crlf[] = "\r\n"; + if (!write_data(strm, crlf, str_len(crlf))) { ok = false; } + }; + + data_sink.done = [&](void) { done_with_trailer(nullptr); }; + + data_sink.done_with_trailer = [&](const Headers &trailer) { + done_with_trailer(&trailer); + }; + + while (data_available && !is_shutting_down()) { + if (!strm.wait_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, 0, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content_chunked(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down, U &compressor) { + auto error = Error::Success; + return write_content_chunked(strm, content_provider, is_shutting_down, + compressor, error); +} + +template +inline bool redirect(T &cli, Request &req, Response &res, + const std::string &path, const std::string &location, + Error &error) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count_ -= 1; + + if (res.status == StatusCode::SeeOther_303 && + (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } + + Response new_res; + + auto ret = cli.send(new_req, new_res, error); + if (ret) { + req = new_req; + res = new_res; + + if (res.location.empty()) { res.location = location; } + } + return ret; +} + +inline std::string params_to_query_str(const Params ¶ms) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += encode_query_param(it->second); + } + return query; +} + +inline void parse_query_text(const char *data, std::size_t size, + Params ¶ms) { + std::set cache; + split(data, data + size, '&', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { return; } + cache.insert(std::move(kv)); + + std::string key; + std::string val; + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data, + std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); + + if (!key.empty()) { + params.emplace(decode_url(key, true), decode_url(val, true)); + } + }); +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { + parse_query_text(s.data(), s.size(), params); +} + +inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto boundary_keyword = "boundary="; + auto pos = content_type.find(boundary_keyword); + if (pos == std::string::npos) { return false; } + auto end = content_type.find(';', pos); + auto beg = pos + strlen(boundary_keyword); + boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg)); + return !boundary.empty(); +} + +inline void parse_disposition_params(const std::string &s, Params ¶ms) { + std::set cache; + split(s.data(), s.data() + s.size(), ';', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { return; } + cache.insert(kv); + + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + + if (!key.empty()) { + params.emplace(trim_double_quotes_copy((key)), + trim_double_quotes_copy((val))); + } + }); +} + +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +inline bool parse_range_header(const std::string &s, Ranges &ranges) { +#else +inline bool parse_range_header(const std::string &s, Ranges &ranges) try { +#endif + auto is_valid = [](const std::string &str) { + return std::all_of(str.cbegin(), str.cend(), + [](unsigned char c) { return std::isdigit(c); }); + }; + + if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) { + const auto pos = static_cast(6); + const auto len = static_cast(s.size() - 6); + auto all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) { return; } + + const auto it = std::find(b, e, '-'); + if (it == e) { + all_valid_ranges = false; + return; + } + + const auto lhs = std::string(b, it); + const auto rhs = std::string(it + 1, e); + if (!is_valid(lhs) || !is_valid(rhs)) { + all_valid_ranges = false; + return; + } + + const auto first = + static_cast(lhs.empty() ? -1 : std::stoll(lhs)); + const auto last = + static_cast(rhs.empty() ? -1 : std::stoll(rhs)); + if ((first == -1 && last == -1) || + (first != -1 && last != -1 && first > last)) { + all_valid_ranges = false; + return; + } + + ranges.emplace_back(first, last); + }); + return all_valid_ranges && !ranges.empty(); + } + return false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +} +#else +} catch (...) { return false; } +#endif + +class MultipartFormDataParser { +public: + MultipartFormDataParser() = default; + + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + dash_boundary_crlf_ = dash_ + boundary_ + crlf_; + crlf_dash_boundary_ = crlf_ + dash_ + boundary_; + } + + bool is_valid() const { return is_valid_; } + + bool parse(const char *buf, size_t n, const ContentReceiver &content_callback, + const MultipartContentHeader &header_callback) { + + buf_append(buf, n); + + while (buf_size() > 0) { + switch (state_) { + case 0: { // Initial boundary + buf_erase(buf_find(dash_boundary_crlf_)); + if (dash_boundary_crlf_.size() > buf_size()) { return true; } + if (!buf_start_with(dash_boundary_crlf_)) { return false; } + buf_erase(dash_boundary_crlf_.size()); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_find(crlf_); + if (pos > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + while (pos < buf_size()) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_erase(crlf_.size()); + state_ = 3; + break; + } + + const auto header = buf_head(pos); + + if (!parse_header(header.data(), header.data() + header.size(), + [&](const std::string &, const std::string &) {})) { + is_valid_ = false; + return false; + } + + constexpr const char header_content_type[] = "Content-Type:"; + + if (start_with_case_ignore(header, header_content_type)) { + file_.content_type = + trim_copy(header.substr(str_len(header_content_type))); + } else { + thread_local const std::regex re_content_disposition( + R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", + std::regex_constants::icase); + + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + Params params; + parse_disposition_params(m[1], params); + + auto it = params.find("name"); + if (it != params.end()) { + file_.name = it->second; + } else { + is_valid_ = false; + return false; + } + + it = params.find("filename"); + if (it != params.end()) { file_.filename = it->second; } + + it = params.find("filename*"); + if (it != params.end()) { + // Only allow UTF-8 encoding... + thread_local const std::regex re_rfc5987_encoding( + R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); + + std::smatch m2; + if (std::regex_match(it->second, m2, re_rfc5987_encoding)) { + file_.filename = decode_url(m2[1], false); // override... + } else { + is_valid_ = false; + return false; + } + } + } + } + buf_erase(pos + crlf_.size()); + pos = buf_find(crlf_); + } + if (state_ != 3) { return true; } + break; + } + case 3: { // Body + if (crlf_dash_boundary_.size() > buf_size()) { return true; } + auto pos = buf_find(crlf_dash_boundary_); + if (pos < buf_size()) { + if (!content_callback(buf_data(), pos)) { + is_valid_ = false; + return false; + } + buf_erase(pos + crlf_dash_boundary_.size()); + state_ = 4; + } else { + auto len = buf_size() - crlf_dash_boundary_.size(); + if (len > 0) { + if (!content_callback(buf_data(), len)) { + is_valid_ = false; + return false; + } + buf_erase(len); + } + return true; + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_size()) { return true; } + if (buf_start_with(crlf_)) { + buf_erase(crlf_.size()); + state_ = 1; + } else { + if (dash_.size() > buf_size()) { return true; } + if (buf_start_with(dash_)) { + buf_erase(dash_.size()); + is_valid_ = true; + buf_erase(buf_size()); // Remove epilogue + } else { + return true; + } + } + break; + } + } + } + + return true; + } + +private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + bool start_with_case_ignore(const std::string &a, const char *b) const { + const auto b_len = strlen(b); + if (a.size() < b_len) { return false; } + for (size_t i = 0; i < b_len; i++) { + if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { + return false; + } + } + return true; + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; + std::string dash_boundary_crlf_; + std::string crlf_dash_boundary_; + + size_t state_ = 0; + bool is_valid_ = false; + MultipartFormData file_; + + // Buffer + bool start_with(const std::string &a, size_t spos, size_t epos, + const std::string &b) const { + if (epos - spos < b.size()) { return false; } + for (size_t i = 0; i < b.size(); i++) { + if (a[i + spos] != b[i]) { return false; } + } + return true; + } + + size_t buf_size() const { return buf_epos_ - buf_spos_; } + + const char *buf_data() const { return &buf_[buf_spos_]; } + + std::string buf_head(size_t l) const { return buf_.substr(buf_spos_, l); } + + bool buf_start_with(const std::string &s) const { + return start_with(buf_, buf_spos_, buf_epos_, s); + } + + size_t buf_find(const std::string &s) const { + auto c = s.front(); + + size_t off = buf_spos_; + while (off < buf_epos_) { + auto pos = off; + while (true) { + if (pos == buf_epos_) { return buf_size(); } + if (buf_[pos] == c) { break; } + pos++; + } + + auto remaining_size = buf_epos_ - pos; + if (s.size() > remaining_size) { return buf_size(); } + + if (start_with(buf_, pos, buf_epos_, s)) { return pos - buf_spos_; } + + off = pos + 1; + } + + return buf_size(); + } + + void buf_append(const char *data, size_t n) { + auto remaining_size = buf_size(); + if (remaining_size > 0 && buf_spos_ > 0) { + for (size_t i = 0; i < remaining_size; i++) { + buf_[i] = buf_[buf_spos_ + i]; + } + } + buf_spos_ = 0; + buf_epos_ = remaining_size; + + if (remaining_size + n > buf_.size()) { buf_.resize(remaining_size + n); } + + for (size_t i = 0; i < n; i++) { + buf_[buf_epos_ + i] = data[i]; + } + buf_epos_ += n; + } + + void buf_erase(size_t size) { buf_spos_ += size; } + + std::string buf_; + size_t buf_spos_ = 0; + size_t buf_epos_ = 0; +}; + +inline std::string random_string(size_t length) { + constexpr const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + thread_local auto engine([]() { + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + std::random_device seed_gen; + // Request 128 bits of entropy for initialization + std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()}; + return std::mt19937(seed_sequence); + }()); + + std::string result; + for (size_t i = 0; i < length; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + return result; +} + +inline std::string make_multipart_data_boundary() { + return "--cpp-httplib-multipart-data-" + detail::random_string(16); +} + +inline bool is_multipart_boundary_chars_valid(const std::string &boundary) { + auto valid = true; + for (size_t i = 0; i < boundary.size(); i++) { + auto c = boundary[i]; + if (!std::isalnum(c) && c != '-' && c != '_') { + valid = false; + break; + } + } + return valid; +} + +template +inline std::string +serialize_multipart_formdata_item_begin(const T &item, + const std::string &boundary) { + std::string body = "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + + return body; +} + +inline std::string serialize_multipart_formdata_item_end() { return "\r\n"; } + +inline std::string +serialize_multipart_formdata_finish(const std::string &boundary) { + return "--" + boundary + "--\r\n"; +} + +inline std::string +serialize_multipart_formdata_get_content_type(const std::string &boundary) { + return "multipart/form-data; boundary=" + boundary; +} + +inline std::string +serialize_multipart_formdata(const MultipartFormDataItems &items, + const std::string &boundary, bool finish = true) { + std::string body; + + for (const auto &item : items) { + body += serialize_multipart_formdata_item_begin(item, boundary); + body += item.content + serialize_multipart_formdata_item_end(); + } + + if (finish) { body += serialize_multipart_formdata_finish(boundary); } + + return body; +} + +inline bool range_error(Request &req, Response &res) { + if (!req.ranges.empty() && 200 <= res.status && res.status < 300) { + ssize_t content_len = static_cast( + res.content_length_ ? res.content_length_ : res.body.size()); + + ssize_t prev_first_pos = -1; + ssize_t prev_last_pos = -1; + size_t overwrapping_count = 0; + + // NOTE: The following Range check is based on '14.2. Range' in RFC 9110 + // 'HTTP Semantics' to avoid potential denial-of-service attacks. + // https://www.rfc-editor.org/rfc/rfc9110#section-14.2 + + // Too many ranges + if (req.ranges.size() > CPPHTTPLIB_RANGE_MAX_COUNT) { return true; } + + for (auto &r : req.ranges) { + auto &first_pos = r.first; + auto &last_pos = r.second; + + if (first_pos == -1 && last_pos == -1) { + first_pos = 0; + last_pos = content_len; + } + + if (first_pos == -1) { + first_pos = content_len - last_pos; + last_pos = content_len - 1; + } + + // NOTE: RFC-9110 '14.1.2. Byte Ranges': + // A client can limit the number of bytes requested without knowing the + // size of the selected representation. If the last-pos value is absent, + // or if the value is greater than or equal to the current length of the + // representation data, the byte range is interpreted as the remainder of + // the representation (i.e., the server replaces the value of last-pos + // with a value that is one less than the current length of the selected + // representation). + // https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2-6 + if (last_pos == -1 || last_pos >= content_len) { + last_pos = content_len - 1; + } + + // Range must be within content length + if (!(0 <= first_pos && first_pos <= last_pos && + last_pos <= content_len - 1)) { + return true; + } + + // Ranges must be in ascending order + if (first_pos <= prev_first_pos) { return true; } + + // Request must not have more than two overlapping ranges + if (first_pos <= prev_last_pos) { + overwrapping_count++; + if (overwrapping_count > 2) { return true; } + } + + prev_first_pos = (std::max)(prev_first_pos, first_pos); + prev_last_pos = (std::max)(prev_last_pos, last_pos); + } + } + + return false; +} + +inline std::pair +get_range_offset_and_length(Range r, size_t content_length) { + assert(r.first != -1 && r.second != -1); + assert(0 <= r.first && r.first < static_cast(content_length)); + assert(r.first <= r.second && + r.second < static_cast(content_length)); + (void)(content_length); + return std::make_pair(r.first, static_cast(r.second - r.first) + 1); +} + +inline std::string make_content_range_header_field( + const std::pair &offset_and_length, size_t content_length) { + auto st = offset_and_length.first; + auto ed = st + offset_and_length.second - 1; + + std::string field = "bytes "; + field += std::to_string(st); + field += "-"; + field += std::to_string(ed); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, + const std::string &boundary, + const std::string &content_type, + size_t content_length, SToken stoken, + CToken ctoken, Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offset_and_length = + get_range_offset_and_length(req.ranges[i], content_length); + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset_and_length, content_length)); + ctoken("\r\n"); + ctoken("\r\n"); + + if (!content(offset_and_length.first, offset_and_length.second)) { + return false; + } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--"); + + return true; +} + +inline void make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + size_t content_length, + std::string &data) { + process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { data += token; }, + [&](const std::string &token) { data += token; }, + [&](size_t offset, size_t length) { + assert(offset + length <= content_length); + data += res.body.substr(offset, length); + return true; + }); +} + +inline size_t get_multipart_ranges_data_length(const Request &req, + const std::string &boundary, + const std::string &content_type, + size_t content_length) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { data_length += token.size(); }, + [&](const std::string &token) { data_length += token.size(); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +template +inline bool +write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + size_t content_length, const T &is_shutting_down) { + return process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { strm.write(token); }, + [&](const std::string &token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, + is_shutting_down); + }); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "DELETE") { + return true; + } + if (req.has_header("Content-Length") && + req.get_header_value_u64("Content-Length") > 0) { + return true; + } + if (is_chunked_transfer_encoding(req.headers)) { return true; } + return false; +} + +inline bool has_crlf(const std::string &s) { + auto p = s.c_str(); + while (*p) { + if (*p == '\r' || *p == '\n') { return true; } + p++; + } + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; + + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + return message_digest(s, EVP_md5()); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest(s, EVP_sha256()); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest(s, EVP_sha512()); +} + +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 + : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" + : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + + cnonce + "\", response=\"") + + response + "\"" + + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +inline bool is_ssl_peer_could_be_closed(SSL *ssl, socket_t sock) { + detail::set_nonblocking(sock, true); + auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + char buf[1]; + return !SSL_peek(ssl, buf, 1) && + SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; +} + +#ifdef _WIN32 +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store +inline bool load_system_certs_on_windows(X509_STORE *store) { + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); + if (!hStore) { return false; } + + auto result = false; + PCCERT_CONTEXT pContext = NULL; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + auto encoded_cert = + static_cast(pContext->pbCertEncoded); + + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + + return result; +} +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX +template +using CFObjectPtr = + std::unique_ptr::type, void (*)(CFTypeRef)>; + +inline void cf_object_ptr_deleter(CFTypeRef obj) { + if (obj) { CFRelease(obj); } +} + +inline bool retrieve_certs_from_keychain(CFObjectPtr &certs) { + CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; + CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, + kCFBooleanTrue}; + + CFObjectPtr query( + CFDictionaryCreate(nullptr, reinterpret_cast(keys), values, + sizeof(keys) / sizeof(keys[0]), + &kCFTypeDictionaryKeyCallBacks, + &kCFTypeDictionaryValueCallBacks), + cf_object_ptr_deleter); + + if (!query) { return false; } + + CFTypeRef security_items = nullptr; + if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || + CFArrayGetTypeID() != CFGetTypeID(security_items)) { + return false; + } + + certs.reset(reinterpret_cast(security_items)); + return true; +} + +inline bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { + CFArrayRef root_security_items = nullptr; + if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { + return false; + } + + certs.reset(root_security_items); + return true; +} + +inline bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { + auto result = false; + for (auto i = 0; i < CFArrayGetCount(certs); ++i) { + const auto cert = reinterpret_cast( + CFArrayGetValueAtIndex(certs, i)); + + if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { continue; } + + CFDataRef cert_data = nullptr; + if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != + errSecSuccess) { + continue; + } + + CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); + + auto encoded_cert = static_cast( + CFDataGetBytePtr(cert_data_ptr.get())); + + auto x509 = + d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); + + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + return result; +} + +inline bool load_system_certs_on_macos(X509_STORE *store) { + auto result = false; + CFObjectPtr certs(nullptr, cf_object_ptr_deleter); + if (retrieve_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store); + } + + if (retrieve_root_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store) || result; + } + + return result; +} +#endif // TARGET_OS_OSX +#endif // _WIN32 +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef _WIN32 +class WSInit { +public: + WSInit() { + WSADATA wsaData; + if (WSAStartup(0x0002, &wsaData) == 0) is_valid_ = true; + } + + ~WSInit() { + if (is_valid_) WSACleanup(); + } + + bool is_valid_ = false; +}; + +static WSInit wsinit_; +#endif + +inline bool parse_www_authenticate(const Response &res, + std::map &auth, + bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + thread_local auto re = + std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + const auto &m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 + ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +class ContentProviderAdapter { +public: + explicit ContentProviderAdapter( + ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) {} + + bool operator()(size_t offset, size_t, DataSink &sink) { + return content_provider_(offset, sink); + } + +private: + ContentProviderWithoutLength content_provider_; +}; + +} // namespace detail + +inline std::string hosted_at(const std::string &hostname) { + std::vector addrs; + hosted_at(hostname, addrs); + if (addrs.empty()) { return std::string(); } + return addrs[0]; +} + +inline void hosted_at(const std::string &hostname, + std::vector &addrs) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(hostname.c_str(), nullptr, &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &addr = + *reinterpret_cast(rp->ai_addr); + std::string ip; + auto dummy = -1; + if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, + dummy)) { + addrs.push_back(ip); + } + } +} + +inline std::string append_query_params(const std::string &path, + const Params ¶ms) { + std::string path_with_query = path; + thread_local const std::regex re("[^?]+\\?.*"); + auto delm = std::regex_match(path, re) ? '&' : '?'; + path_with_query += delm + detail::params_to_query_str(params); + return path_with_query; +} + +// Header utilities +inline std::pair +make_range_header(const Ranges &ranges) { + std::string field = "bytes="; + auto i = 0; + for (const auto &r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", std::move(field)); +} + +inline std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, bool is_proxy) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +inline std::pair +make_bearer_token_authentication_header(const std::string &token, + bool is_proxy = false) { + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +// Request implementation +inline bool Request::has_header(const std::string &key) const { + return detail::has_header(headers, key); +} + +inline std::string Request::get_header_value(const std::string &key, + const char *def, size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Request::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Request::set_header(const std::string &key, + const std::string &val) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { + headers.emplace(key, val); + } +} + +inline bool Request::has_param(const std::string &key) const { + return params.find(key) != params.end(); +} + +inline std::string Request::get_param_value(const std::string &key, + size_t id) const { + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return std::string(); +} + +inline size_t Request::get_param_value_count(const std::string &key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.rfind("multipart/form-data", 0); +} + +inline bool Request::has_file(const std::string &key) const { + return files.find(key) != files.end(); +} + +inline MultipartFormData Request::get_file_value(const std::string &key) const { + auto it = files.find(key); + if (it != files.end()) { return it->second; } + return MultipartFormData(); +} + +inline std::vector +Request::get_file_values(const std::string &key) const { + std::vector values; + auto rng = files.equal_range(key); + for (auto it = rng.first; it != rng.second; it++) { + values.push_back(it->second); + } + return values; +} + +// Response implementation +inline bool Response::has_header(const std::string &key) const { + return headers.find(key) != headers.end(); +} + +inline std::string Response::get_header_value(const std::string &key, + const char *def, + size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Response::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_header(const std::string &key, + const std::string &val) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { + headers.emplace(key, val); + } +} + +inline void Response::set_redirect(const std::string &url, int stat) { + if (detail::fields::is_field_value(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = StatusCode::Found_302; + } + } +} + +inline void Response::set_content(const char *s, size_t n, + const std::string &content_type) { + body.assign(s, n); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(const std::string &s, + const std::string &content_type) { + set_content(s.data(), s.size(), content_type); +} + +inline void Response::set_content(std::string &&s, + const std::string &content_type) { + body = std::move(s); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider( + size_t in_length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = in_length; + if (in_length > 0) { content_provider_ = std::move(provider); } + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_chunked_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = true; +} + +inline void Response::set_file_content(const std::string &path, + const std::string &content_type) { + file_content_path_ = path; + file_content_content_type_ = content_type; +} + +inline void Response::set_file_content(const std::string &path) { + file_content_path_ = path; +} + +// Result implementation +inline bool Result::has_request_header(const std::string &key) const { + return request_headers_.find(key) != request_headers_.end(); +} + +inline std::string Result::get_request_header_value(const std::string &key, + const char *def, + size_t id) const { + return detail::get_header_value(request_headers_, key, def, id); +} + +inline size_t +Result::get_request_header_value_count(const std::string &key) const { + auto r = request_headers_.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Stream implementation +inline ssize_t Stream::write(const char *ptr) { + return write(ptr, strlen(ptr)); +} + +inline ssize_t Stream::write(const std::string &s) { + return write(s.data(), s.size()); +} + +namespace detail { + +inline void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, + time_t timeout_sec, time_t timeout_usec, + time_t &actual_timeout_sec, + time_t &actual_timeout_usec) { + auto timeout_msec = (timeout_sec * 1000) + (timeout_usec / 1000); + + auto actual_timeout_msec = + (std::min)(max_timeout_msec - duration_msec, timeout_msec); + + if (actual_timeout_msec < 0) { actual_timeout_msec = 0; } + + actual_timeout_sec = actual_timeout_msec / 1000; + actual_timeout_usec = (actual_timeout_msec % 1000) * 1000; +} + +// Socket stream implementation +inline SocketStream::SocketStream( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time_(start_time), + read_buff_(read_buff_size_, 0) {} + +inline SocketStream::~SocketStream() = default; + +inline bool SocketStream::is_readable() const { + return read_buff_off_ < read_buff_content_size_; +} + +inline bool SocketStream::wait_readable() const { + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; +} + +inline bool SocketStream::wait_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_); +} + +inline ssize_t SocketStream::read(char *ptr, size_t size) { +#ifdef _WIN32 + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#else + size = (std::min)(size, + static_cast((std::numeric_limits::max)())); +#endif + + if (read_buff_off_ < read_buff_content_size_) { + auto remaining_size = read_buff_content_size_ - read_buff_off_; + if (size <= remaining_size) { + memcpy(ptr, read_buff_.data() + read_buff_off_, size); + read_buff_off_ += size; + return static_cast(size); + } else { + memcpy(ptr, read_buff_.data() + read_buff_off_, remaining_size); + read_buff_off_ += remaining_size; + return static_cast(remaining_size); + } + } + + if (!wait_readable()) { return -1; } + + read_buff_off_ = 0; + read_buff_content_size_ = 0; + + if (size < read_buff_size_) { + auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, + CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + return n; + } else if (n <= static_cast(size)) { + memcpy(ptr, read_buff_.data(), static_cast(n)); + return n; + } else { + memcpy(ptr, read_buff_.data(), size); + read_buff_off_ = size; + read_buff_content_size_ = static_cast(n); + return static_cast(size); + } + } else { + return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + } +} + +inline ssize_t SocketStream::write(const char *ptr, size_t size) { + if (!wait_writable()) { return -1; } + +#if defined(_WIN32) && !defined(_WIN64) + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + return send_socket(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); +} + +inline void SocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + return detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SocketStream::socket() const { return sock_; } + +inline time_t SocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time_) + .count(); +} + +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::wait_readable() const { return true; } + +inline bool BufferStream::wait_writable() const { return true; } + +inline ssize_t BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1910 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); +} + +inline ssize_t BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline socket_t BufferStream::socket() const { return 0; } + +inline time_t BufferStream::duration() const { return 0; } + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { + constexpr const char marker[] = "/:"; + + // One past the last ending position of a path param substring + std::size_t last_param_end = 0; + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + // Needed to ensure that parameter names are unique during matcher + // construction + // If exceptions are disabled, only last duplicate path + // parameter will be set + std::unordered_set param_name_set; +#endif + + while (true) { + const auto marker_pos = pattern.find( + marker, last_param_end == 0 ? last_param_end : last_param_end - 1); + if (marker_pos == std::string::npos) { break; } + + static_fragments_.push_back( + pattern.substr(last_param_end, marker_pos - last_param_end + 1)); + + const auto param_name_start = marker_pos + str_len(marker); + + auto sep_pos = pattern.find(separator, param_name_start); + if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } + + auto param_name = + pattern.substr(param_name_start, sep_pos - param_name_start); + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (param_name_set.find(param_name) != param_name_set.cend()) { + std::string msg = "Encountered path parameter '" + param_name + + "' multiple times in route pattern '" + pattern + "'."; + throw std::invalid_argument(msg); + } +#endif + + param_names_.push_back(std::move(param_name)); + + last_param_end = sep_pos + 1; + } + + if (last_param_end < pattern.length()) { + static_fragments_.push_back(pattern.substr(last_param_end)); + } +} + +inline bool PathParamsMatcher::match(Request &request) const { + request.matches = std::smatch(); + request.path_params.clear(); + request.path_params.reserve(param_names_.size()); + + // One past the position at which the path matched the pattern last time + std::size_t starting_pos = 0; + for (size_t i = 0; i < static_fragments_.size(); ++i) { + const auto &fragment = static_fragments_[i]; + + if (starting_pos + fragment.length() > request.path.length()) { + return false; + } + + // Avoid unnecessary allocation by using strncmp instead of substr + + // comparison + if (std::strncmp(request.path.c_str() + starting_pos, fragment.c_str(), + fragment.length()) != 0) { + return false; + } + + starting_pos += fragment.length(); + + // Should only happen when we have a static fragment after a param + // Example: '/users/:id/subscriptions' + // The 'subscriptions' fragment here does not have a corresponding param + if (i >= param_names_.size()) { continue; } + + auto sep_pos = request.path.find(separator, starting_pos); + if (sep_pos == std::string::npos) { sep_pos = request.path.length(); } + + const auto ¶m_name = param_names_[i]; + + request.path_params.emplace( + param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); + + // Mark everything up to '/' as matched + starting_pos = sep_pos + 1; + } + // Returns false if the path is longer than the pattern + return starting_pos >= request.path.length(); +} + +inline bool RegexMatcher::match(Request &request) const { + request.path_params.clear(); + return std::regex_match(request.path, request.matches, regex_); +} + +} // namespace detail + +// HTTP server implementation +inline Server::Server() + : new_task_queue( + [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { +#ifndef _WIN32 + signal(SIGPIPE, SIG_IGN); +#endif +} + +inline Server::~Server() = default; + +inline std::unique_ptr +Server::make_matcher(const std::string &pattern) { + if (pattern.find("/:") != std::string::npos) { + return detail::make_unique(pattern); + } else { + return detail::make_unique(pattern); + } +} + +inline Server &Server::Get(const std::string &pattern, Handler handler) { + get_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, Handler handler) { + post_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, Handler handler) { + put_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, Handler handler) { + patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, Handler handler) { + delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, + HandlerWithContentReader handler) { + delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Options(const std::string &pattern, Handler handler) { + options_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline bool Server::set_base_dir(const std::string &dir, + const std::string &mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const std::string &mount_point, + const std::string &dir, Headers headers) { + detail::FileStat stat(dir); + if (stat.is_dir()) { + std::string mnt = !mount_point.empty() ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.push_back({mnt, dir, std::move(headers)}); + return true; + } + } + return false; +} + +inline bool Server::remove_mount_point(const std::string &mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->mount_point == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline Server & +Server::set_file_extension_and_mimetype_mapping(const std::string &ext, + const std::string &mime) { + file_extension_and_mimetype_map_[ext] = mime; + return *this; +} + +inline Server &Server::set_default_file_mimetype(const std::string &mime) { + default_file_mimetype_ = mime; + return *this; +} + +inline Server &Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(HandlerWithResponse handler, + std::true_type) { + error_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(Handler handler, + std::false_type) { + error_handler_ = [handler](const Request &req, Response &res) { + handler(req, res); + return HandlerResponse::Handled; + }; + return *this; +} + +inline Server &Server::set_exception_handler(ExceptionHandler handler) { + exception_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_pre_routing_handler(HandlerWithResponse handler) { + pre_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_post_routing_handler(Handler handler) { + post_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_logger(Logger logger) { + logger_ = std::move(logger); + return *this; +} + +inline Server & +Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_address_family(int family) { + address_family_ = family; + return *this; +} + +inline Server &Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; + return *this; +} + +inline Server &Server::set_ipv6_v6only(bool on) { + ipv6_v6only_ = on; + return *this; +} + +inline Server &Server::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); + return *this; +} + +inline Server &Server::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); + return *this; +} + +inline Server &Server::set_header_writer( + std::function const &writer) { + header_writer_ = writer; + return *this; +} + +inline Server &Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + return *this; +} + +inline Server &Server::set_keep_alive_timeout(time_t sec) { + keep_alive_timeout_sec_ = sec; + return *this; +} + +inline Server &Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_idle_interval(time_t sec, time_t usec) { + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; + return *this; +} + +inline Server &Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; + return *this; +} + +inline bool Server::bind_to_port(const std::string &host, int port, + int socket_flags) { + auto ret = bind_internal(host, port, socket_flags); + if (ret == -1) { is_decommissioned = true; } + return ret >= 0; +} +inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { + auto ret = bind_internal(host, 0, socket_flags); + if (ret == -1) { is_decommissioned = true; } + return ret; +} + +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const std::string &host, int port, + int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::wait_until_ready() const { + while (!is_running_ && !is_decommissioned) { + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } +} + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + is_decommissioned = false; +} + +inline void Server::decommission() { is_decommissioned = true; } + +inline bool Server::parse_request_line(const char *s, Request &req) const { + auto len = strlen(s); + if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; } + len -= 2; + + { + size_t count = 0; + + detail::split(s, s + len, ' ', [&](const char *b, const char *e) { + switch (count) { + case 0: req.method = std::string(b, e); break; + case 1: req.target = std::string(b, e); break; + case 2: req.version = std::string(b, e); break; + default: break; + } + count++; + }); + + if (count != 3) { return false; } + } + + thread_local const std::set methods{ + "GET", "HEAD", "POST", "PUT", "DELETE", + "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; + + if (methods.find(req.method) == methods.end()) { return false; } + + if (req.version != "HTTP/1.1" && req.version != "HTTP/1.0") { return false; } + + { + // Skip URL fragment + for (size_t i = 0; i < req.target.size(); i++) { + if (req.target[i] == '#') { + req.target.erase(i); + break; + } + } + + detail::divide(req.target, '?', + [&](const char *lhs_data, std::size_t lhs_size, + const char *rhs_data, std::size_t rhs_size) { + req.path = detail::decode_url( + std::string(lhs_data, lhs_size), false); + detail::parse_query_text(rhs_data, rhs_size, req.params); + }); + } + + return true; +} + +inline bool Server::write_response(Stream &strm, bool close_connection, + Request &req, Response &res) { + // NOTE: `req.ranges` should be empty, otherwise it will be applied + // incorrectly to the error content. + req.ranges.clear(); + return write_response_core(strm, close_connection, req, res, false); +} + +inline bool Server::write_response_with_content(Stream &strm, + bool close_connection, + const Request &req, + Response &res) { + return write_response_core(strm, close_connection, req, res, true); +} + +inline bool Server::write_response_core(Stream &strm, bool close_connection, + const Request &req, Response &res, + bool need_apply_ranges) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_ && + error_handler_(req, res) == HandlerResponse::Handled) { + need_apply_ranges = true; + } + + std::string content_type; + std::string boundary; + if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } + + // Prepare additional headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::string s = "timeout="; + s += std::to_string(keep_alive_timeout_sec_); + s += ", max="; + s += std::to_string(keep_alive_max_count_); + res.set_header("Keep-Alive", s); + } + + if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) && + !res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } + + if (res.body.empty() && !res.content_length_ && !res.content_provider_ && + !res.has_header("Content-Length")) { + res.set_header("Content-Length", "0"); + } + + if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } + + if (post_routing_handler_) { post_routing_handler_(req, res); } + + // Response line and headers + { + detail::BufferStream bstrm; + if (!detail::write_response_line(bstrm, res.status)) { return false; } + if (!header_writer_(bstrm, res.headers)) { return false; } + + // Flush buffer + auto &data = bstrm.get_buffer(); + detail::write_data(strm, data.data(), data.size()); + } + + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!detail::write_data(strm, res.body.data(), res.body.size())) { + ret = false; + } + } else if (res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, content_type)) { + res.content_provider_success_ = true; + } else { + ret = false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return ret; +} + +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + return detail::write_content(strm, res.content_provider_, 0, + res.content_length_, is_shutting_down); + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length( + req.ranges[0], res.content_length_); + + return detail::write_content(strm, res.content_provider_, + offset_and_length.first, + offset_and_length.second, is_shutting_down); + } else { + return detail::write_multipart_ranges_data( + strm, req, res, boundary, content_type, res.content_length_, + is_shutting_down); + } + } else { + if (res.is_chunked_content_provider_) { + auto type = detail::encoding_type(req, res); + + std::unique_ptr compressor; + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); +#endif + } else { + compressor = detail::make_unique(); + } + assert(compressor != nullptr); + + return detail::write_content_chunked(strm, res.content_provider_, + is_shutting_down, *compressor); + } else { + return detail::write_content_without_length(strm, res.content_provider_, + is_shutting_down); + } + } +} + +inline bool Server::read_content(Stream &strm, Request &req, Response &res) { + MultipartFormDataMap::iterator cur; + auto file_count = 0; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + if (file_count++ == CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT) { + return false; + } + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) { + res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414? + return false; + } + detail::parse_query_text(req.body, req.params); + } + return true; + } + return false; +} + +inline bool Server::read_content_with_content_receiver( + Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, req, res, std::move(receiver), + std::move(multipart_header), + std::move(multipart_receiver)); +} + +inline bool +Server::read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) const { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiverWithProgress out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = StatusCode::BadRequest_400; + return false; + } + + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { + /* For debug + size_t pos = 0; + while (pos < n) { + auto read_size = (std::min)(1, n - pos); + auto ret = multipart_form_data_parser.parse( + buf + pos, read_size, multipart_receiver, multipart_header); + if (!ret) { return false; } + pos += read_size; + } + return true; + */ + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + multipart_header); + }; + } else { + out = [receiver](const char *buf, size_t n, uint64_t /*off*/, + uint64_t /*len*/) { return receiver(buf, n); }; + } + + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, + out, true)) { + return false; + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = StatusCode::BadRequest_400; + return false; + } + } + + return true; +} + +inline bool Server::handle_file_request(const Request &req, Response &res, + bool head) { + for (const auto &entry : base_dirs_) { + // Prefix match + if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { + std::string sub_path = "/" + req.path.substr(entry.mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = entry.base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } + + detail::FileStat stat(path); + + if (stat.is_dir()) { + res.set_redirect(sub_path + "/", StatusCode::MovedPermanently_301); + return true; + } + + if (stat.is_file()) { + for (const auto &kv : entry.headers) { + res.set_header(kv.first, kv.second); + } + + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { return false; } + + res.set_content_provider( + mm->size(), + detail::find_content_type(path, file_extension_and_mimetype_map_, + default_file_mimetype_), + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + + return true; + } + } + } + } + return false; +} + +inline socket_t +Server::create_server_socket(const std::string &host, int port, + int socket_flags, + SocketOptions socket_options) const { + return detail::create_socket( + host, std::string(), port, address_family_, socket_flags, tcp_nodelay_, + ipv6_v6only_, std::move(socket_options), + [](socket_t sock, struct addrinfo &ai, bool & /*quit*/) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, CPPHTTPLIB_LISTEN_BACKLOG)) { return false; } + return true; + }); +} + +inline int Server::bind_internal(const std::string &host, int port, + int socket_flags) { + if (is_decommissioned) { return -1; } + + if (!is_valid()) { return -1; } + + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { return -1; } + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + if (is_decommissioned) { return false; } + + auto ret = true; + is_running_ = true; + auto se = detail::scope_exit([&]() { is_running_ = false; }); + + { + std::unique_ptr task_queue(new_task_queue()); + + while (svr_sock_ != INVALID_SOCKET) { +#ifndef _WIN32 + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { +#endif + auto val = detail::select_read(svr_sock_, idle_interval_sec_, + idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } +#ifndef _WIN32 + } +#endif + +#if defined _WIN32 + // sockets connected via WASAccept inherit flags NO_HANDLE_INHERIT, + // OVERLAPPED + socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); +#elif defined SOCK_CLOEXEC + socket_t sock = accept4(svr_sock_, nullptr, nullptr, SOCK_CLOEXEC); +#else + socket_t sock = accept(svr_sock_, nullptr, nullptr); +#endif + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } else if (errno == EINTR || errno == EAGAIN) { + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } + + detail::set_socket_opt_time(sock, SOL_SOCKET, SO_RCVTIMEO, + read_timeout_sec_, read_timeout_usec_); + detail::set_socket_opt_time(sock, SOL_SOCKET, SO_SNDTIMEO, + write_timeout_sec_, write_timeout_usec_); + + if (!task_queue->enqueue( + [this, sock]() { process_and_close_socket(sock); })) { + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + } + + task_queue->shutdown(); + } + + is_decommissioned = !ret; + return ret; +} + +inline bool Server::routing(Request &req, Response &res, Stream &strm) { + if (pre_routing_handler_ && + pre_routing_handler_(req, res) == HandlerResponse::Handled) { + return true; + } + + // File handler + auto is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, req, res, std::move(receiver), nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, + std::move(header), + std::move(receiver)); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { return false; } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = StatusCode::BadRequest_400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, + const Handlers &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res); + return true; + } + } + return false; +} + +inline void Server::apply_ranges(const Request &req, Response &res, + std::string &content_type, + std::string &boundary) const { + if (req.ranges.size() > 1 && res.status == StatusCode::PartialContent_206) { + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + boundary = detail::make_multipart_data_boundary(); + + res.set_header("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } + + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length( + req.ranges[0], res.content_length_); + + length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field( + offset_and_length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length( + req, boundary, content_type, res.content_length_); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider_) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } else if (type == detail::EncodingType::Zstd) { + res.set_header("Content-Encoding", "zstd"); + } + } + } + } + } else { + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + ; + } else if (req.ranges.size() == 1) { + auto offset_and_length = + detail::get_range_offset_and_length(req.ranges[0], res.body.size()); + auto offset = offset_and_length.first; + auto length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field( + offset_and_length, res.body.size()); + res.set_header("Content-Range", content_range); + + assert(offset + length <= res.body.size()); + res.body = res.body.substr(offset, length); + } else { + std::string data; + detail::make_multipart_ranges_data(req, res, boundary, content_type, + res.body.size(), data); + res.body.swap(data); + } + + if (type != detail::EncodingType::None) { + std::unique_ptr compressor; + std::string content_encoding; + + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); + content_encoding = "gzip"; +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); + content_encoding = "br"; +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); + content_encoding = "zstd"; +#endif + } + + if (compressor) { + std::string compressed; + if (compressor->compress(res.body.data(), res.body.size(), true, + [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + res.body.swap(compressed); + res.set_header("Content-Encoding", content_encoding); + } + } + } + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } +} + +inline bool Server::dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + +inline bool +Server::process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, + bool &connection_closed, + const std::function &setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + + Response res; + res.version = "HTTP/1.1"; + res.headers = default_headers_; + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + + // Check if the request URI doesn't exceed the limit + if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::UriTooLong_414; + return write_response(strm, close_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } + + req.remote_addr = remote_addr; + req.remote_port = remote_port; + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + + req.local_addr = local_addr; + req.local_port = local_port; + req.set_header("LOCAL_ADDR", req.local_addr); + req.set_header("LOCAL_PORT", std::to_string(req.local_port)); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + } + + if (setup_request) { setup_request(req); } + + if (req.get_header_value("Expect") == "100-continue") { + int status = StatusCode::Continue_100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case StatusCode::Continue_100: + case StatusCode::ExpectationFailed_417: + detail::write_response_line(strm, status); + strm.write("\r\n"); + break; + default: + connection_closed = true; + return write_response(strm, true, req, res); + } + } + + // Setup `is_connection_closed` method + req.is_connection_closed = [&]() { + return !detail::is_socket_alive(strm.socket()); + }; + + // Routing + auto routed = false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + routed = routing(req, res, strm); +#else + try { + routed = routing(req, res, strm); + } catch (std::exception &e) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + std::string val; + auto s = e.what(); + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case '\r': val += "\\r"; break; + case '\n': val += "\\n"; break; + default: val += s[i]; break; + } + } + res.set_header("EXCEPTION_WHAT", val); + } + } catch (...) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + } +#endif + if (routed) { + if (res.status == -1) { + res.status = req.ranges.empty() ? StatusCode::OK_200 + : StatusCode::PartialContent_206; + } + + // Serve file content by using a content provider + if (!res.file_content_path_.empty()) { + const auto &path = res.file_content_path_; + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::NotFound_404; + return write_response(strm, close_connection, req, res); + } + + auto content_type = res.file_content_content_type_; + if (content_type.empty()) { + content_type = detail::find_content_type( + path, file_extension_and_mimetype_map_, default_file_mimetype_); + } + + res.set_content_provider( + mm->size(), content_type, + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + } + + if (detail::range_error(req, res)) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + + return write_response_with_content(strm, close_connection, req, res); + } else { + if (res.status == -1) { res.status = StatusCode::NotFound_404; } + + return write_response(strm, close_connection, req, res); + } +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + auto ret = detail::process_server_socket( + svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, connection_closed, + nullptr); + }); + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// HTTP client implementation +inline ClientImpl::ClientImpl(const std::string &host) + : ClientImpl(host, 80, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port) + : ClientImpl(host, port, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port), + host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + +inline ClientImpl::~ClientImpl() { + // Wait until all the requests in flight are handled. + size_t retry_count = 10; + while (retry_count-- > 0) { + { + std::lock_guard guard(socket_mutex_); + if (socket_requests_in_flight_ == 0) { break; } + } + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } + + std::lock_guard guard(socket_mutex_); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline bool ClientImpl::is_valid() const { return true; } + +inline void ClientImpl::copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + max_timeout_msec_ = rhs.max_timeout_msec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + url_encode_ = rhs.url_encode_; + address_family_ = rhs.address_family_; + tcp_nodelay_ = rhs.tcp_nodelay_; + ipv6_v6only_ = rhs.ipv6_v6only_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + ca_cert_store_ = rhs.ca_cert_store_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; + server_certificate_verifier_ = rhs.server_certificate_verifier_; +#endif + logger_ = rhs.logger_; +} + +inline socket_t ClientImpl::create_client_socket(Error &error) const { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket( + proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_, + ipv6_v6only_, socket_options_, connection_timeout_sec_, + connection_timeout_usec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, interface_, error); + } + + // Check is custom IP specified for host_ + std::string ip; + auto it = addr_map_.find(host_); + if (it != addr_map_.end()) { ip = it->second; } + + return detail::create_client_socket( + host_, ip, port_, address_family_, tcp_nodelay_, ipv6_v6only_, + socket_options_, connection_timeout_sec_, connection_timeout_usec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, interface_, error); +} + +inline bool ClientImpl::create_and_connect_socket(Socket &socket, + Error &error) { + auto sock = create_client_socket(error); + if (sock == INVALID_SOCKET) { return false; } + socket.sock = sock; + return true; +} + +inline void ClientImpl::shutdown_ssl(Socket & /*socket*/, + bool /*shutdown_gracefully*/) { + // If there are any requests in flight from threads other than us, then it's + // a thread-unsafe race because individual ssl* objects are not thread-safe. + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); +} + +inline void ClientImpl::shutdown_socket(Socket &socket) const { + if (socket.sock == INVALID_SOCKET) { return; } + detail::shutdown_socket(socket.sock); +} + +inline void ClientImpl::close_socket(Socket &socket) { + // If there are requests in flight in another thread, usually closing + // the socket will be fine and they will simply receive an error when + // using the closed socket, but it is still a bug since rarely the OS + // may reassign the socket id to be used for a new socket, and then + // suddenly they will be operating on a live socket that is different + // than the one they intended! + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); + + // It is also a bug if this happens while SSL is still active +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + assert(socket.ssl == nullptr); +#endif + if (socket.sock == INVALID_SOCKET) { return; } + detail::close_socket(socket.sock); + socket.sock = INVALID_SOCKET; +} + +inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, + Response &res) const { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { return false; } + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); +#else + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); +#endif + + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return req.method == "CONNECT"; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + + // Ignore '100 Continue' + while (res.status == StatusCode::Continue_100) { + if (!line_reader.getline()) { return false; } // CRLF + if (!line_reader.getline()) { return false; } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; +} + +inline bool ClientImpl::send(Request &req, Response &res, Error &error) { + std::lock_guard request_mutex_guard(request_mutex_); + auto ret = send_(req, res, error); + if (error == Error::SSLPeerCouldBeClosed_) { + assert(!ret); + ret = send_(req, res, error); + } + return ret; +} + +inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { + { + std::lock_guard guard(socket_mutex_); + + // Set this to false immediately - if it ever gets set to true by the end of + // the request, we know another thread instructed us to close the socket. + socket_should_be_closed_when_request_is_done_ = false; + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + is_alive = false; + } + } +#endif + + if (!is_alive) { + // Attempt to avoid sigpipe by shutting down non-gracefully if it seems + // like the other side has already closed the connection Also, there + // cannot be any requests in flight from other threads since we locked + // request_mutex_, so safe to close everything immediately + const bool shutdown_gracefully = false; + shutdown_ssl(socket_, shutdown_gracefully); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!create_and_connect_socket(socket_, error)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + auto success = false; + if (!scli.connect_with_proxy(socket_, req.start_time_, res, success, + error)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_, error)) { return false; } + } +#endif + } + + // Mark the current socket as being in use so that it cannot be closed by + // anyone else while this request is ongoing, even though we will be + // releasing the mutex. + if (socket_requests_in_flight_ > 1) { + assert(socket_requests_are_from_thread_ == std::this_thread::get_id()); + } + socket_requests_in_flight_ += 1; + socket_requests_are_from_thread_ = std::this_thread::get_id(); + } + + for (const auto &header : default_headers_) { + if (req.headers.find(header.first) == req.headers.end()) { + req.headers.insert(header); + } + } + + auto ret = false; + auto close_connection = !keep_alive_; + + auto se = detail::scope_exit([&]() { + // Briefly lock mutex in order to mark that a request is no longer ongoing + std::lock_guard guard(socket_mutex_); + socket_requests_in_flight_ -= 1; + if (socket_requests_in_flight_ <= 0) { + assert(socket_requests_in_flight_ == 0); + socket_requests_are_from_thread_ = std::thread::id(); + } + + if (socket_should_be_closed_when_request_is_done_ || close_connection || + !ret) { + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + }); + + ret = process_socket(socket_, req.start_time_, [&](Stream &strm) { + return handle_request(strm, req, res, close_connection, error); + }); + + if (!ret) { + if (error == Error::Success) { error = Error::Unknown; } + } + + return ret; +} + +inline Result ClientImpl::send(const Request &req) { + auto req2 = req; + return send_(std::move(req2)); +} + +inline Result ClientImpl::send_(Request &&req) { + auto res = detail::make_unique(); + auto error = Error::Success; + auto ret = send(req, *res, error); + return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)}; +} + +inline bool ClientImpl::handle_request(Stream &strm, Request &req, + Response &res, bool close_connection, + Error &error) { + if (req.path.empty()) { + error = Error::Connection; + return false; + } + + auto req_save = req; + + bool ret; + + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection, error); + req = req2; + req.path = req_save.path; + } else { + ret = process_request(strm, req, res, close_connection, error); + } + + if (!ret) { return false; } + + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + // TODO this requires a not-entirely-obvious chain of calls to be correct + // for this to be safe. + + // This is safe to call because handle_request is only called by send_ + // which locks the request mutex during the process. It would be a bug + // to call it from a different thread since it's a thread-safety issue + // to do these things to the socket if another thread is using the socket. + std::lock_guard guard(socket_mutex_); + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + + if (300 < res.status && res.status < 400 && follow_location_) { + req = req_save; + ret = redirect(req, res, error); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if ((res.status == StatusCode::Unauthorized_401 || + res.status == StatusCode::ProxyAuthenticationRequired_407) && + req.authorization_count_ < 5) { + auto is_proxy = res.status == StatusCode::ProxyAuthenticationRequired_407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + new_req.headers.erase(is_proxy ? "Proxy-Authorization" + : "Authorization"); + new_req.headers.insert(detail::make_digest_authentication_header( + req, auth, new_req.authorization_count_, detail::random_string(10), + username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res, error); + if (ret) { res = new_res; } + } + } + } +#endif + + return ret; +} + +inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { + if (req.redirect_count_ == 0) { + error = Error::ExceedRedirectCount; + return false; + } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + thread_local const std::regex re( + R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + + std::smatch m; + if (!std::regex_match(location, m, re)) { return false; } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + if (next_host.empty()) { next_host = m[3].str(); } + auto port_str = m[4].str(); + auto next_path = m[5].str(); + auto next_query = m[6].str(); + + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } + + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + auto path = detail::decode_url(next_path, true) + next_query; + + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, path, location, error); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host, next_port); + cli.copy_settings(*this); + if (ca_cert_store_) { cli.set_ca_cert_store(ca_cert_store_); } + return detail::redirect(cli, req, res, path, location, error); +#else + return false; +#endif + } else { + ClientImpl cli(next_host, next_port); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, path, location, error); + } + } +} + +inline bool ClientImpl::write_content_with_provider(Stream &strm, + const Request &req, + Error &error) const { + auto is_shutting_down = []() { return false; }; + + if (req.is_chunked_content_provider_) { + // TODO: Brotli support + std::unique_ptr compressor; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + compressor = detail::make_unique(); + } else +#endif + { + compressor = detail::make_unique(); + } + + return detail::write_content_chunked(strm, req.content_provider_, + is_shutting_down, *compressor, error); + } else { + return detail::write_content(strm, req.content_provider_, 0, + req.content_length_, is_shutting_down, error); + } +} + +inline bool ClientImpl::write_request(Stream &strm, Request &req, + bool close_connection, Error &error) { + // Prepare additional headers + if (close_connection) { + if (!req.has_header("Connection")) { + req.set_header("Connection", "close"); + } + } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } else { + if (port_ == 80) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } + + if (!req.content_receiver) { + if (!req.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; +#endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "gzip, deflate"; +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "zstd"; +#endif + req.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!req.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.set_header("User-Agent", agent); + } +#endif + }; + + if (req.body.empty()) { + if (req.content_provider_) { + if (!req.is_chunked_content_provider_) { + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.content_length_); + req.set_header("Content-Length", length); + } + } + } else { + if (req.method == "POST" || req.method == "PUT" || + req.method == "PATCH") { + req.set_header("Content-Length", "0"); + } + } + } else { + if (!req.has_header("Content-Type")) { + req.set_header("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + req.set_header("Content-Length", length); + } + } + + if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + } + + if (!bearer_token_auth_token_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_bearer_token_authentication_header( + bearer_token_auth_token_, false)); + } + } + + if (!proxy_bearer_token_auth_token_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_bearer_token_authentication_header( + proxy_bearer_token_auth_token_, true)); + } + } + + // Request line and headers + { + detail::BufferStream bstrm; + + const auto &path_with_query = + req.params.empty() ? req.path + : append_query_params(req.path, req.params); + + const auto &path = + url_encode_ ? detail::encode_url(path_with_query) : path_with_query; + + detail::write_request_line(bstrm, req.method, path); + + header_writer_(bstrm, req.headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error = Error::Write; + return false; + } + } + + // Body + if (req.body.empty()) { + return write_content_with_provider(strm, req, error); + } + + if (!detail::write_data(strm, req.body.data(), req.body.size())) { + error = Error::Write; + return false; + } + + return true; +} + +inline std::unique_ptr ClientImpl::send_with_content_provider( + Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error) { + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { req.set_header("Content-Encoding", "gzip"); } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_ && !content_provider_without_length) { + // TODO: Brotli support + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + DataSink data_sink; + + data_sink.write = [&](const char *data, size_t data_len) -> bool { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = compressor.compress( + data, data_len, last, + [&](const char *compressed_data, size_t compressed_data_len) { + req.body.append(compressed_data, compressed_data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + return ok; + }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body, content_length, true, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + error = Error::Compression; + return nullptr; + } + } + } else +#endif + { + if (content_provider) { + req.content_length_ = content_length; + req.content_provider_ = std::move(content_provider); + req.is_chunked_content_provider_ = false; + } else if (content_provider_without_length) { + req.content_length_ = 0; + req.content_provider_ = detail::ContentProviderAdapter( + std::move(content_provider_without_length)); + req.is_chunked_content_provider_ = true; + req.set_header("Transfer-Encoding", "chunked"); + } else { + req.body.assign(body, content_length); + } + } + + auto res = detail::make_unique(); + return send(req, *res, error) ? std::move(res) : nullptr; +} + +inline Result ClientImpl::send_with_content_provider( + const std::string &method, const std::string &path, const Headers &headers, + const char *body, size_t content_length, ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Progress progress) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + req.progress = progress; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + auto error = Error::Success; + + auto res = send_with_content_provider( + req, body, content_length, std::move(content_provider), + std::move(content_provider_without_length), content_type, error); + + return Result{std::move(res), error, std::move(req.headers)}; +} + +inline std::string +ClientImpl::adjust_host_string(const std::string &host) const { + if (host.find(':') != std::string::npos) { return "[" + host + "]"; } + return host; +} + +inline bool ClientImpl::process_request(Stream &strm, Request &req, + Response &res, bool close_connection, + Error &error) { + // Send request + if (!write_request(strm, req, close_connection, error)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; + if (!is_proxy_enabled) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + error = Error::SSLPeerCouldBeClosed_; + return false; + } + } + } +#endif + + // Receive response and headers + if (!read_response_line(strm, req, res) || + !detail::read_headers(strm, res.headers)) { + error = Error::Read; + return false; + } + + // Body + if ((res.status != StatusCode::NoContent_204) && req.method != "HEAD" && + req.method != "CONNECT") { + auto redirect = 300 < res.status && res.status < 400 && + res.status != StatusCode::NotModified_304 && + follow_location_; + + if (req.response_handler && !redirect) { + if (!req.response_handler(res)) { + error = Error::Canceled; + return false; + } + } + + auto out = + req.content_receiver + ? static_cast( + [&](const char *buf, size_t n, uint64_t off, uint64_t len) { + if (redirect) { return true; } + auto ret = req.content_receiver(buf, n, off, len); + if (!ret) { error = Error::Canceled; } + return ret; + }) + : static_cast( + [&](const char *buf, size_t n, uint64_t /*off*/, + uint64_t /*len*/) { + assert(res.body.size() + n <= res.body.max_size()); + res.body.append(buf, n); + return true; + }); + + auto progress = [&](uint64_t current, uint64_t total) { + if (!req.progress || redirect) { return true; } + auto ret = req.progress(current, total); + if (!ret) { error = Error::Canceled; } + return ret; + }; + + if (res.has_header("Content-Length")) { + if (!req.content_receiver) { + auto len = res.get_header_value_u64("Content-Length"); + if (len > res.body.max_size()) { + error = Error::Read; + return false; + } + res.body.reserve(static_cast(len)); + } + } + + if (res.status != StatusCode::NotModified_304) { + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, std::move(progress), + std::move(out), decompress_)) { + if (error != Error::Canceled) { error = Error::Read; } + return false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; +} + +inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider( + const std::string &boundary, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) const { + size_t cur_item = 0; + size_t cur_start = 0; + // cur_item and cur_start are copied to within the std::function and maintain + // state between successive calls + return [&, cur_item, cur_start](size_t offset, + DataSink &sink) mutable -> bool { + if (!offset && !items.empty()) { + sink.os << detail::serialize_multipart_formdata(items, boundary, false); + return true; + } else if (cur_item < provider_items.size()) { + if (!cur_start) { + const auto &begin = detail::serialize_multipart_formdata_item_begin( + provider_items[cur_item], boundary); + offset += begin.size(); + cur_start = offset; + sink.os << begin; + } + + DataSink cur_sink; + auto has_data = true; + cur_sink.write = sink.write; + cur_sink.done = [&]() { has_data = false; }; + + if (!provider_items[cur_item].provider(offset - cur_start, cur_sink)) { + return false; + } + + if (!has_data) { + sink.os << detail::serialize_multipart_formdata_item_end(); + cur_item++; + cur_start = 0; + } + return true; + } else { + sink.os << detail::serialize_multipart_formdata_finish(boundary); + sink.done(); + return true; + } + }; +} + +inline bool ClientImpl::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { + return detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, max_timeout_msec_, start_time, std::move(callback)); +} + +inline bool ClientImpl::is_ssl() const { return false; } + +inline Result ClientImpl::Get(const std::string &path) { + return Get(path, Headers(), Progress()); +} + +inline Result ClientImpl::Get(const std::string &path, Progress progress) { + return Get(path, Headers(), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers) { + return Get(path, headers, Progress()); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, + ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, Headers(), std::move(response_handler), + std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), + std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + uint64_t /*offset*/, uint64_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress) { + if (params.empty()) { return Get(path, headers); } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, params, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + if (params.empty()) { + return Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Head(const std::string &path) { + return Head(path, Headers()); +} + +inline Result ClientImpl::Head(const std::string &path, + const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Post(const std::string &path) { + return Post(path, std::string(), std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers) { + return Post(path, headers, nullptr, 0, std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Post(path, Headers(), body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body, content_length, + nullptr, nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("POST", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, + const std::string &content_type) { + return Post(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return Post(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("POST", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Params ¶ms) { + return Post(path, Headers(), params); +} + +inline Result ClientImpl::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Post(path, Headers(), content_length, std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Post(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded", + progress); +} + +inline Result ClientImpl::Post(const std::string &path, + const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type); +} + +inline Result +ClientImpl::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "POST", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path) { + return Put(path, std::string(), std::string()); +} + +inline Result ClientImpl::Put(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Put(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body, content_length, + nullptr, nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PUT", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, + const std::string &content_type) { + return Put(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return Put(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PUT", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Put(path, Headers(), content_length, std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Put(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Params ¶ms) { + return Put(path, Headers(), params); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded", + progress); +} + +inline Result ClientImpl::Put(const std::string &path, + const MultipartFormDataItems &items) { + return Put(path, Headers(), items); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result +ClientImpl::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "PUT", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type, nullptr); +} +inline Result ClientImpl::Patch(const std::string &path) { + return Patch(path, std::string(), std::string()); +} + +inline Result ClientImpl::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Patch(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return Patch(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return Patch(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PATCH", path, headers, body, + content_length, nullptr, nullptr, + content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, + const std::string &body, + const std::string &content_type) { + return Patch(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Patch(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return Patch(path, headers, body, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PATCH", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Patch(path, Headers(), content_length, std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Patch(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + nullptr); +} + +inline Result ClientImpl::Delete(const std::string &path) { + return Delete(path, Headers(), std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers) { + return Delete(path, headers, std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Delete(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return Delete(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const char *body, + size_t content_length, + const std::string &content_type) { + return Delete(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + req.progress = progress; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + req.body.assign(body, content_length); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Delete(const std::string &path, + const std::string &body, + const std::string &content_type) { + return Delete(path, Headers(), body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Delete(path, Headers(), body.data(), body.size(), content_type, + progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return Delete(path, headers, body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Delete(path, headers, body.data(), body.size(), content_type, + progress); +} + +inline Result ClientImpl::Options(const std::string &path) { + return Options(path, Headers()); +} + +inline Result ClientImpl::Options(const std::string &path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.headers = headers; + req.path = path; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline void ClientImpl::stop() { + std::lock_guard guard(socket_mutex_); + + // If there is anything ongoing right now, the ONLY thread-safe thing we can + // do is to shutdown_socket, so that threads using this socket suddenly + // discover they can't read/write any more and error out. Everything else + // (closing the socket, shutting ssl down) is unsafe because these actions are + // not thread-safe. + if (socket_requests_in_flight_ > 0) { + shutdown_socket(socket_); + + // Aside from that, we set a flag for the socket to be closed when we're + // done. + socket_should_be_closed_when_request_is_done_ = true; + return; + } + + // Otherwise, still holding the mutex, we can shut everything down ourselves + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline std::string ClientImpl::host() const { return host_; } + +inline int ClientImpl::port() const { return port_; } + +inline size_t ClientImpl::is_socket_open() const { + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); +} + +inline socket_t ClientImpl::socket() const { return socket_.sock; } + +inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; +} + +inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +inline void ClientImpl::set_max_timeout(time_t msec) { + max_timeout_msec_ = msec; +} + +inline void ClientImpl::set_basic_auth(const std::string &username, + const std::string &password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +inline void ClientImpl::set_bearer_token_auth(const std::string &token) { + bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_digest_auth(const std::string &username, + const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } + +inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } + +inline void ClientImpl::set_url_encode(bool on) { url_encode_ = on; } + +inline void +ClientImpl::set_hostname_addr_map(std::map addr_map) { + addr_map_ = std::move(addr_map); +} + +inline void ClientImpl::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); +} + +inline void ClientImpl::set_header_writer( + std::function const &writer) { + header_writer_ = writer; +} + +inline void ClientImpl::set_address_family(int family) { + address_family_ = family; +} + +inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } + +inline void ClientImpl::set_ipv6_v6only(bool on) { ipv6_v6only_ = on; } + +inline void ClientImpl::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); +} + +inline void ClientImpl::set_compress(bool on) { compress_ = on; } + +inline void ClientImpl::set_decompress(bool on) { decompress_ = on; } + +inline void ClientImpl::set_interface(const std::string &intf) { + interface_ = intf; +} + +inline void ClientImpl::set_proxy(const std::string &host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void ClientImpl::set_proxy_basic_auth(const std::string &username, + const std::string &password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +inline void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { + proxy_bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} + +inline void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path) { + ca_cert_file_path_ = ca_cert_file_path; + ca_cert_dir_path_ = ca_cert_dir_path; +} + +inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store && ca_cert_store != ca_cert_store_) { + ca_cert_store_ = ca_cert_store; + } +} + +inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, + std::size_t size) const { + auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); + auto se = detail::scope_exit([&] { BIO_free_all(mem); }); + if (!mem) { return nullptr; } + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { return nullptr; } + + auto cts = X509_STORE_new(); + if (cts) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { continue; } + + if (itmp->x509) { X509_STORE_add_cert(cts, itmp->x509); } + if (itmp->crl) { X509_STORE_add_crl(cts, itmp->crl); } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + return cts; +} + +inline void ClientImpl::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +inline void ClientImpl::enable_server_hostname_verification(bool enabled) { + server_hostname_verification_ = enabled; +} + +inline void ClientImpl::set_server_certificate_verifier( + std::function verifier) { + server_certificate_verifier_ = verifier; +} +#endif + +inline void ClientImpl::set_logger(Logger logger) { + logger_ = std::move(logger); +} + +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, + U SSL_connect_or_accept, V setup) { + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (ssl) { + set_nonblocking(sock, true); + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + BIO_set_nbio(bio, 1); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + set_nonblocking(sock, false); + return nullptr; + } + BIO_set_nbio(bio, 0); + set_nonblocking(sock, false); + } + + return ssl; +} + +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, + bool shutdown_gracefully) { + // sometimes we may want to skip this to try to avoid SIGPIPE if we know + // the remote has closed the network connection + // Note that it is not always possible to avoid SIGPIPE, this is merely a + // best-efforts. + if (shutdown_gracefully) { + (void)(sock); + // SSL_shutdown() returns 0 on first call (indicating close_notify alert + // sent) and 1 on subsequent call (indicating close_notify alert received) + if (SSL_shutdown(ssl) == 0) { + // Expected to return 1, but even if it doesn't, we free ssl + SSL_shutdown(ssl); + } + } + + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); +} + +template +bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, + U ssl_connect_or_accept, + time_t timeout_sec, + time_t timeout_usec) { + auto res = 0; + while ((res = ssl_connect_or_accept(ssl)) != 1) { + auto err = SSL_get_error(ssl, res); + switch (err) { + case SSL_ERROR_WANT_READ: + if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + case SSL_ERROR_WANT_WRITE: + if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + default: break; + } + return false; + } + return true; +} + +template +inline bool process_server_socket_ssl( + const std::atomic &svr_sock, SSL *ssl, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +inline bool process_client_socket_ssl( + SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, T callback) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); + return callback(strm); +} + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream( + socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time_(start_time) { + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); +} + +inline SSLSocketStream::~SSLSocketStream() = default; + +inline bool SSLSocketStream::is_readable() const { + return SSL_pending(ssl_) > 0; +} + +inline bool SSLSocketStream::wait_readable() const { + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; +} + +inline bool SSLSocketStream::wait_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); +} + +inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (wait_readable()) { + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err == SSL_ERROR_WANT_READ || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { +#endif + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (wait_readable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } else { + return -1; + } +} + +inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (wait_writable()) { + auto handle_size = static_cast( + std::min(size, (std::numeric_limits::max)())); + + auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { +#endif + if (wait_writable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SSLSocketStream::socket() const { return sock_; } + +inline time_t SSLSocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time_) + .count(); +} + +} // namespace detail + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, + const char *private_key_password) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (private_key_password != nullptr && (private_key_password[0] != '\0')) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, + reinterpret_cast(const_cast(private_key_password))); + } + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1 || + SSL_CTX_check_private_key(ctx_) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer( + const std::function &setup_ssl_ctx_callback) { + ctx_ = SSL_CTX_new(TLS_method()); + if (ctx_) { + if (!setup_ssl_ctx_callback(*ctx_)) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } +} + +inline bool SSLServer::is_valid() const { return ctx_; } + +inline SSL_CTX *SSLServer::ssl_context() const { return ctx_; } + +inline void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + + std::lock_guard guard(ctx_mutex_); + + SSL_CTX_use_certificate(ctx_, cert); + SSL_CTX_use_PrivateKey(ctx_, private_key); + + if (client_ca_cert_store != nullptr) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + } +} + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + auto ssl = detail::ssl_new( + sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + return detail::ssl_connect_or_accept_nonblocking( + sock, ssl2, SSL_accept, read_timeout_sec_, read_timeout_usec_); + }, + [](SSL * /*ssl2*/) { return true; }); + + auto ret = false; + if (ssl) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + ret = detail::process_server_socket_ssl( + svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, + connection_closed, + [&](Request &req) { req.ssl = ssl; }); + }); + + // Shutdown gracefully if the result seemed successful, non-gracefully if + // the connection appeared to be closed. + const bool shutdown_gracefully = ret; + detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); + } + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const std::string &host) + : SSLClient(host, 443, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port) + : SSLClient(host, port, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password) + : ClientImpl(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(b, e); + }); + + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::SSLClient(const std::string &host, int port, + X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password) + : ClientImpl(host, port) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(b, e); + }); + + if (client_cert != nullptr && client_key != nullptr) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { + if (ctx_) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { + // Free memory allocated for old cert and use new store `ca_cert_store` + SSL_CTX_set_cert_store(ctx_, ca_cert_store); + } + } else { + X509_STORE_free(ca_cert_store); + } + } +} + +inline void SSLClient::load_ca_cert_store(const char *ca_cert, + std::size_t size) { + set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); +} + +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } + +inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + return is_valid() && ClientImpl::create_and_connect_socket(socket, error); +} + +// Assumes that socket_mutex_ is locked and that there are no requests in flight +inline bool SSLClient::connect_with_proxy( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error) { + success = true; + Response proxy_res; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + if (max_timeout_msec_ > 0) { + req2.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req2, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(proxy_res, auth, true)) { + proxy_res = Response(); + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + if (max_timeout_msec_ > 0) { + req3.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req3, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } + } + + // If status code is not 200, proxy request is failed. + // Set error to ProxyConnection and return proxy response + // as the response of the request + if (proxy_res.status != StatusCode::OK_200) { + error = Error::ProxyConnection; + res = std::move(proxy_res); + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +inline bool SSLClient::load_certs() { + auto ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), + nullptr)) { + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, + ca_cert_dir_path_.c_str())) { + ret = false; + } + } else { + auto loaded = false; +#ifdef _WIN32 + loaded = + detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX + loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); +#endif // TARGET_OS_OSX +#endif // _WIN32 + if (!loaded) { SSL_CTX_set_default_verify_paths(ctx_); } + } + }); + + return ret; +} + +inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + auto ssl = detail::ssl_new( + socket.sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); + } + + if (!detail::ssl_connect_or_accept_nonblocking( + socket.sock, ssl2, SSL_connect, connection_timeout_sec_, + connection_timeout_usec_)) { + error = Error::SSLConnection; + return false; + } + + if (server_certificate_verification_) { + auto verification_status = SSLVerifierResponse::NoDecisionMade; + + if (server_certificate_verifier_) { + verification_status = server_certificate_verifier_(ssl2); + } + + if (verification_status == SSLVerifierResponse::CertificateRejected) { + error = Error::SSLServerVerification; + return false; + } + + if (verification_status == SSLVerifierResponse::NoDecisionMade) { + verify_result_ = SSL_get_verify_result(ssl2); + + if (verify_result_ != X509_V_OK) { + error = Error::SSLServerVerification; + return false; + } + + auto server_cert = SSL_get1_peer_certificate(ssl2); + auto se = detail::scope_exit([&] { X509_free(server_cert); }); + + if (server_cert == nullptr) { + error = Error::SSLServerVerification; + return false; + } + + if (server_hostname_verification_) { + if (!verify_host(server_cert)) { + error = Error::SSLServerHostnameVerification; + return false; + } + } + } + } + + return true; + }, + [&](SSL *ssl2) { +#if defined(OPENSSL_IS_BORINGSSL) + SSL_set_tlsext_host_name(ssl2, host_.c_str()); +#else + // NOTE: Direct call instead of using the OpenSSL macro to suppress + // -Wold-style-cast warning + SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(host_.c_str()))); +#endif + return true; + }); + + if (ssl) { + socket.ssl = ssl; + return true; + } + + shutdown_socket(socket); + close_socket(socket); + return false; +} + +inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +inline void SSLClient::shutdown_ssl_impl(Socket &socket, + bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, + shutdown_gracefully); + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +inline bool SSLClient::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, + std::move(callback)); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +inline bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6 = {}; + struct in_addr addr = {}; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = + reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + } + + if (dsn_matched || ip_matched) { ret = true; } + } + + GENERAL_NAMES_free(const_cast( + reinterpret_cast(alt_names))); + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(b, e); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; +} +#endif + +// Universal client implementation +inline Client::Client(const std::string &scheme_host_port) + : Client(scheme_host_port, std::string(), std::string()) {} + +inline Client::Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path) { + const static std::regex re( + R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + + std::smatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { +#else + if (!scheme.empty() && scheme != "http") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "https"; + + auto host = m[2].str(); + if (host.empty()) { host = m[3].str(); } + + auto port_str = m[4].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + + if (is_ssl) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + cli_ = detail::make_unique(host, port, client_cert_path, + client_key_path); + is_ssl_ = is_ssl; +#endif + } else { + cli_ = detail::make_unique(host, port, client_cert_path, + client_key_path); + } + } else { + // NOTE: Update TEST(UniversalClientImplTest, Ipv6LiteralAddress) + // if port param below changes. + cli_ = detail::make_unique(scheme_host_port, 80, + client_cert_path, client_key_path); + } +} // namespace detail + +inline Client::Client(const std::string &host, int port) + : cli_(detail::make_unique(host, port)) {} + +inline Client::Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : cli_(detail::make_unique(host, port, client_cert_path, + client_key_path)) {} + +inline Client::~Client() = default; + +inline bool Client::is_valid() const { + return cli_ != nullptr && cli_->is_valid(); +} + +inline Result Client::Get(const std::string &path) { return cli_->Get(path); } +inline Result Client::Get(const std::string &path, const Headers &headers) { + return cli_->Get(path, headers); +} +inline Result Client::Get(const std::string &path, Progress progress) { + return cli_->Get(path, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + Progress progress) { + return cli_->Get(path, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, + ContentReceiver content_receiver) { + return cli_->Get(path, std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, headers, std::move(content_receiver), + std::move(progress)); +} +inline Result Client::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress) { + return cli_->Get(path, params, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, params, headers, std::move(content_receiver), + std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, params, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result Client::Head(const std::string &path) { return cli_->Head(path); } +inline Result Client::Head(const std::string &path, const Headers &headers) { + return cli_->Head(path, headers); +} + +inline Result Client::Post(const std::string &path) { return cli_->Post(path); } +inline Result Client::Post(const std::string &path, const Headers &headers) { + return cli_->Post(path, headers); +} +inline Result Client::Post(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Post(path, body, content_length, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_length, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Post(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Post(path, body, content_type); +} +inline Result Client::Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Post(path, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Post(path, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Post(path, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Post(path, headers, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Post(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Params ¶ms) { + return cli_->Post(path, params); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Post(path, headers, params); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + return cli_->Post(path, headers, params, progress); +} +inline Result Client::Post(const std::string &path, + const MultipartFormDataItems &items) { + return cli_->Post(path, items); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + return cli_->Post(path, headers, items); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Post(path, headers, items, boundary); +} +inline Result +Client::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Post(path, headers, items, provider_items); +} +inline Result Client::Put(const std::string &path) { return cli_->Put(path); } +inline Result Client::Put(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Put(path, body, content_length, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_length, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Put(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Put(path, body, content_type); +} +inline Result Client::Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Put(path, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Put(path, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Put(path, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Put(path, headers, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Put(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Params ¶ms) { + return cli_->Put(path, params); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Put(path, headers, params); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + return cli_->Put(path, headers, params, progress); +} +inline Result Client::Put(const std::string &path, + const MultipartFormDataItems &items) { + return cli_->Put(path, items); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + return cli_->Put(path, headers, items); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Put(path, headers, items, boundary); +} +inline Result +Client::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Put(path, headers, items, provider_items); +} +inline Result Client::Patch(const std::string &path) { + return cli_->Patch(path); +} +inline Result Client::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, body, content_length, content_type); +} +inline Result Client::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, body, content_length, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_length, content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Patch(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Patch(path, body, content_type); +} +inline Result Client::Patch(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, headers, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Patch(path, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Patch(path, std::move(content_provider), content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Patch(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Delete(const std::string &path) { + return cli_->Delete(path); +} +inline Result Client::Delete(const std::string &path, const Headers &headers) { + return cli_->Delete(path, headers); +} +inline Result Client::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, body, content_length, content_type); +} +inline Result Client::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, body, content_length, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_length, content_type); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Delete(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Delete(path, body, content_type); +} +inline Result Client::Delete(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, body, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_type); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, headers, body, content_type, progress); +} +inline Result Client::Options(const std::string &path) { + return cli_->Options(path); +} +inline Result Client::Options(const std::string &path, const Headers &headers) { + return cli_->Options(path, headers); +} + +inline bool Client::send(Request &req, Response &res, Error &error) { + return cli_->send(req, res, error); +} + +inline Result Client::send(const Request &req) { return cli_->send(req); } + +inline void Client::stop() { cli_->stop(); } + +inline std::string Client::host() const { return cli_->host(); } + +inline int Client::port() const { return cli_->port(); } + +inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } + +inline socket_t Client::socket() const { return cli_->socket(); } + +inline void +Client::set_hostname_addr_map(std::map addr_map) { + cli_->set_hostname_addr_map(std::move(addr_map)); +} + +inline void Client::set_default_headers(Headers headers) { + cli_->set_default_headers(std::move(headers)); +} + +inline void Client::set_header_writer( + std::function const &writer) { + cli_->set_header_writer(writer); +} + +inline void Client::set_address_family(int family) { + cli_->set_address_family(family); +} + +inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } + +inline void Client::set_socket_options(SocketOptions socket_options) { + cli_->set_socket_options(std::move(socket_options)); +} + +inline void Client::set_connection_timeout(time_t sec, time_t usec) { + cli_->set_connection_timeout(sec, usec); +} + +inline void Client::set_read_timeout(time_t sec, time_t usec) { + cli_->set_read_timeout(sec, usec); +} + +inline void Client::set_write_timeout(time_t sec, time_t usec) { + cli_->set_write_timeout(sec, usec); +} + +inline void Client::set_basic_auth(const std::string &username, + const std::string &password) { + cli_->set_basic_auth(username, password); +} +inline void Client::set_bearer_token_auth(const std::string &token) { + cli_->set_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_digest_auth(username, password); +} +#endif + +inline void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } +inline void Client::set_follow_location(bool on) { + cli_->set_follow_location(on); +} + +inline void Client::set_url_encode(bool on) { cli_->set_url_encode(on); } + +inline void Client::set_compress(bool on) { cli_->set_compress(on); } + +inline void Client::set_decompress(bool on) { cli_->set_decompress(on); } + +inline void Client::set_interface(const std::string &intf) { + cli_->set_interface(intf); +} + +inline void Client::set_proxy(const std::string &host, int port) { + cli_->set_proxy(host, port); +} +inline void Client::set_proxy_basic_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_basic_auth(username, password); +} +inline void Client::set_proxy_bearer_token_auth(const std::string &token) { + cli_->set_proxy_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} + +inline void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +inline void Client::set_server_certificate_verifier( + std::function verifier) { + cli_->set_server_certificate_verifier(verifier); +} +#endif + +inline void Client::set_logger(Logger logger) { + cli_->set_logger(std::move(logger)); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path) { + cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); +} + +inline void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } else { + cli_->set_ca_cert_store(ca_cert_store); + } +} + +inline void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); +} + +inline long Client::get_openssl_verify_result() const { + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? +} + +inline SSL_CTX *Client::ssl_context() const { + if (is_ssl_) { return static_cast(*cli_).ssl_context(); } + return nullptr; +} +#endif + +// ---------------------------------------------------------------------------- + +} // namespace httplib + +#endif // CPPHTTPLIB_HTTPLIB_H diff --git a/common/icsearch.hpp b/common/icsearch.hpp new file mode 100644 index 0000000..38a57b2 --- /dev/null +++ b/common/icsearch.hpp @@ -0,0 +1,255 @@ +#pragma once +#include +#include +#include +#include +#include +#include "logger.hpp" + + +namespace bite_im{ +bool Serialize(const Json::Value &val, std::string &dst) +{ + //先定义Json::StreamWriter 工厂类 Json::StreamWriterBuilder + Json::StreamWriterBuilder swb; + swb.settings_["emitUTF8"] = true; + std::unique_ptr sw(swb.newStreamWriter()); + //通过Json::StreamWriter中的write接口进行序列化 + std::stringstream ss; + int ret = sw->write(val, &ss); + if (ret != 0) { + std::cout << "Json反序列化失败!\n"; + return false; + } + dst = ss.str(); + return true; +} +bool UnSerialize(const std::string &src, Json::Value &val) +{ + Json::CharReaderBuilder crb; + std::unique_ptr cr(crb.newCharReader()); + std::string err; + bool ret = cr->parse(src.c_str(), src.c_str() + src.size(), &val, &err); + if (ret == false) { + std::cout << "json反序列化失败: " << err << std::endl; + return false; + } + return true; +} + +class ESIndex { + public: + ESIndex(std::shared_ptr &client, + const std::string &name, + const std::string &type = "_doc"): + _name(name), _type(type), _client(client) { + Json::Value analysis; + Json::Value analyzer; + Json::Value ik; + Json::Value tokenizer; + tokenizer["tokenizer"] = "ik_max_word"; + ik["ik"] = tokenizer; + analyzer["analyzer"] = ik; + analysis["analysis"] = analyzer; + _index["settings"] = analysis; + } + ESIndex& append(const std::string &key, + const std::string &type = "text", + const std::string &analyzer = "ik_max_word", + bool enabled = true) { + Json::Value fields; + fields["type"] = type; + fields["analyzer"] = analyzer; + if (enabled == false ) fields["enabled"] = enabled; + _properties[key] = fields; + return *this; + } + bool create(const std::string &index_id = "default_index_id") { + Json::Value mappings; + mappings["dynamic"] = true; + mappings["properties"] = _properties; + _index["mappings"] = mappings; + + std::string body; + bool ret = Serialize(_index, body); + if (ret == false) { + LOG_ERROR("索引序列化失败!"); + return false; + } + // LOG_DEBUG("{}", body); + //2. 发起搜索请求 + try { + auto rsp = _client->index(_name, _type, index_id, body); + if (rsp.status_code < 200 || rsp.status_code >= 300) { + LOG_ERROR("创建ES索引 {} 失败,响应状态码异常: {}", _name, rsp.status_code); + return false; + } + } catch(std::exception &e) { + LOG_ERROR("创建ES索引 {} 失败: {}", _name, e.what()); + return false; + } + return true; + } + private: + std::string _name; + std::string _type; + Json::Value _properties; + Json::Value _index; + std::shared_ptr _client; +}; + +class ESInsert { + public: + ESInsert(std::shared_ptr &client, + const std::string &name, + const std::string &type = "_doc"): + _name(name), _type(type), _client(client) + {} + + template + ESInsert &append(const std::string &key, const T &val){ + _item[key] = val; + return *this; + } + + bool insert(const std::string id = "") { + std::string body; + bool ret = Serialize(_item, body); + if (ret == false) { + LOG_ERROR("索引序列化失败!"); + return false; + } + LOG_DEBUG("{}", body); + //2. 发起搜索请求 + try { + auto rsp = _client->index(_name, _type, id, body); + if (rsp.status_code < 200 || rsp.status_code >= 300) { + LOG_ERROR("新增数据 {} 失败,响应状态码异常: {}", body, rsp.status_code); + return false; + } + } catch(std::exception &e) { + LOG_ERROR("新增数据 {} 失败: {}", body, e.what()); + return false; + } + return true; + } + private: + std::string _name; + std::string _type; + Json::Value _item; + std::shared_ptr _client; +}; + +class ESRemove { + public: + ESRemove(std::shared_ptr &client, + const std::string &name, + const std::string &type = "_doc"): + _name(name), _type(type), _client(client){} + bool remove(const std::string &id) { + try { + auto rsp = _client->remove(_name, _type, id); + if (rsp.status_code < 200 || rsp.status_code >= 300) { + LOG_ERROR("删除数据 {} 失败,响应状态码异常: {}", id, rsp.status_code); + return false; + } + } catch(std::exception &e) { + LOG_ERROR("删除数据 {} 失败: {}", id, e.what()); + return false; + } + return true; + } + private: + std::string _name; + std::string _type; + std::shared_ptr _client; +}; + +class ESSearch { + public: + ESSearch(std::shared_ptr &client, + const std::string &name, + const std::string &type = "_doc"): + _name(name), _type(type), _client(client){} + ESSearch& append_must_not_terms(const std::string &key, const std::vector &vals) { + Json::Value fields; + for (const auto& val : vals){ + fields[key].append(val); + } + Json::Value terms; + terms["terms"] = fields; + _must_not.append(terms); + return *this; + } + ESSearch& append_should_match(const std::string &key, const std::string &val) { + Json::Value field; + field[key] = val; + Json::Value match; + match["match"] = field; + _should.append(match); + return *this; + } + ESSearch& append_must_term(const std::string &key, const std::string &val) { + Json::Value field; + field[key] = val; + Json::Value term; + term["term"] = field; + _must.append(term); + return *this; + } + ESSearch& append_must_match(const std::string &key, const std::string &val){ + Json::Value field; + field[key] = val; + Json::Value match; + match["match"] = field; + _must.append(match); + return *this; + } + Json::Value search(){ + Json::Value cond; + if (_must_not.empty() == false) cond["must_not"] = _must_not; + if (_should.empty() == false) cond["should"] = _should; + if (_must.empty() == false) cond["must"] = _must; + Json::Value query; + query["bool"] = cond; + Json::Value root; + root["query"] = query; + + std::string body; + bool ret = Serialize(root, body); + if (ret == false) { + LOG_ERROR("索引序列化失败!"); + return Json::Value(); + } + LOG_DEBUG("{}", body); + //2. 发起搜索请求 + cpr::Response rsp; + try { + rsp = _client->search(_name, _type, body); + if (rsp.status_code < 200 || rsp.status_code >= 300) { + LOG_ERROR("检索数据 {} 失败,响应状态码异常: {}", body, rsp.status_code); + return Json::Value(); + } + } catch(std::exception &e) { + LOG_ERROR("检索数据 {} 失败: {}", body, e.what()); + return Json::Value(); + } + //3. 需要对响应正文进行反序列化 + LOG_DEBUG("检索响应正文: [{}]", rsp.text); + Json::Value json_res; + ret = UnSerialize(rsp.text, json_res); + if (ret == false) { + LOG_ERROR("检索数据 {} 结果反序列化失败", rsp.text); + return Json::Value(); + } + return json_res["hits"]["hits"]; + } + private: + std::string _name; + std::string _type; + Json::Value _must_not; + Json::Value _should; + Json::Value _must; + std::shared_ptr _client; +}; +} \ No newline at end of file diff --git a/common/logger.hpp b/common/logger.hpp new file mode 100644 index 0000000..3f0bec5 --- /dev/null +++ b/common/logger.hpp @@ -0,0 +1,34 @@ +#pragma once +#include +#include +#include +#include +#include + +// mode - 运行模式: true-发布模式; false调试模式 + +namespace bite_im{ +std::shared_ptr g_default_logger; +void init_logger(bool mode, const std::string &file, int32_t level) +{ + if (mode == false) { + //如果是调试模式,则创建标准输出日志器,输出等级为最低 + g_default_logger = spdlog::stdout_color_mt("default-logger"); + g_default_logger->set_level(spdlog::level::level_enum::trace); + g_default_logger->flush_on(spdlog::level::level_enum::trace); + }else { + //否则是发布模式,则创建文件输出日志器,输出等级根据参数而定 + g_default_logger = spdlog::basic_logger_mt("default-logger", file); + g_default_logger->set_level((spdlog::level::level_enum)level); + g_default_logger->flush_on((spdlog::level::level_enum)level); + } + g_default_logger->set_pattern("[%n][%H:%M:%S][%t][%-8l]%v"); +} + +#define LOG_TRACE(format, ...) bite_im::g_default_logger->trace(std::string("[{}:{}] ") + format, __FILE__, __LINE__, ##__VA_ARGS__) +#define LOG_DEBUG(format, ...) bite_im::g_default_logger->debug(std::string("[{}:{}] ") + format, __FILE__, __LINE__, ##__VA_ARGS__) +#define LOG_INFO(format, ...) bite_im::g_default_logger->info(std::string("[{}:{}] ") + format, __FILE__, __LINE__, ##__VA_ARGS__) +#define LOG_WARN(format, ...) bite_im::g_default_logger->warn(std::string("[{}:{}] ") + format, __FILE__, __LINE__, ##__VA_ARGS__) +#define LOG_ERROR(format, ...) bite_im::g_default_logger->error(std::string("[{}:{}] ") + format, __FILE__, __LINE__, ##__VA_ARGS__) +#define LOG_FATAL(format, ...) bite_im::g_default_logger->critical(std::string("[{}:{}] ") + format, __FILE__, __LINE__, ##__VA_ARGS__) +} \ No newline at end of file diff --git a/common/mysql.hpp b/common/mysql.hpp new file mode 100644 index 0000000..a1f8384 --- /dev/null +++ b/common/mysql.hpp @@ -0,0 +1,30 @@ +#pragma once +#include +#include // std::auto_ptr +#include // std::exit +#include +#include +#include +#include "logger.hpp" + +// 用户注册, 用户登录, 验证码获取, 手机号注册,手机号登录, 获取用户信息, 用户信息修改 +// 用信息新增, 通过昵称获取用户信息,通过手机号获取用户信息, 通过用户ID获取用户信息, 通过多个用户ID获取多个用户信息,信息修改 +namespace bite_im { +class ODBFactory { + public: + static std::shared_ptr create( + const std::string &user, + const std::string &pswd, + const std::string &host, + const std::string &db, + const std::string &cset, + int port, + int conn_pool_count) { + std::unique_ptr cpf( + new odb::mysql::connection_pool_factory(conn_pool_count, 0)); + auto res = std::make_shared(user, pswd, + db, host, port, "", cset, 0, std::move(cpf)); + return res; + } +}; +} \ No newline at end of file diff --git a/common/mysql_apply.hpp b/common/mysql_apply.hpp new file mode 100644 index 0000000..7547f10 --- /dev/null +++ b/common/mysql_apply.hpp @@ -0,0 +1,71 @@ +#pragma once +#include "mysql.hpp" +#include "friend_apply.hxx" +#include "friend_apply-odb.hxx" + +namespace bite_im { + class FriendApplyTable { + public: + using ptr = std::shared_ptr; + FriendApplyTable(const std::shared_ptr &db) : _db(db){} + bool insert(FriendApply &ev) { + try { + odb::transaction trans(_db->begin()); + _db->persist(ev); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("新增好友申请事件失败 {}-{}:{}!", ev.user_id(), ev.peer_id(), e.what()); + return false; + } + return true; + } + bool exists(const std::string &uid, const std::string &pid) { + bool flag = false; + try { + typedef odb::query query; + typedef odb::result result; + odb::transaction trans(_db->begin()); + result r(_db->query(query::user_id == uid && query::peer_id == pid)); + LOG_DEBUG("{} - {} 好友事件数量:{}", uid, pid, r.size()); + flag = !r.empty(); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("获取好友申请事件失败:{}-{}-{}!", uid, pid, e.what()); + } + return flag; + } + bool remove(const std::string &uid, const std::string &pid) { + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + _db->erase_query(query::user_id == uid && query::peer_id == pid); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("删除好友申请事件失败 {}-{}:{}!", uid, pid, e.what()); + return false; + } + return true; + } + //获取当前指定用户的 所有好友申请者ID + std::vector applyUsers(const std::string &uid){ + std::vector res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + //当前的uid是被申请者的用户ID + result r(_db->query(query::peer_id == uid)); + for (result::iterator i(r.begin()); i != r.end(); ++i) { + res.push_back(i->user_id()); + } + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("通过用户{}的好友申请者失败:{}!", uid, e.what()); + } + return res; + } + private: + std::shared_ptr _db; + }; +} \ No newline at end of file diff --git a/common/mysql_chat_session.hpp b/common/mysql_chat_session.hpp new file mode 100644 index 0000000..906063d --- /dev/null +++ b/common/mysql_chat_session.hpp @@ -0,0 +1,118 @@ +#pragma once +#include "mysql.hpp" +#include "chat_session.hxx" +#include "chat_session-odb.hxx" +#include "mysql_chat_session_member.hpp" + +namespace bite_im { + class ChatSessionTable { + public: + using ptr = std::shared_ptr; + ChatSessionTable(const std::shared_ptr &db):_db(db){} + bool insert(ChatSession &cs) { + try { + odb::transaction trans(_db->begin()); + _db->persist(cs); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("新增会话失败 {}:{}!", cs.chat_session_name(), e.what()); + return false; + } + return true; + } + bool remove(const std::string &ssid) { + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + _db->erase_query(query::chat_session_id == ssid); + + typedef odb::query mquery; + _db->erase_query(mquery::session_id == ssid); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("删除会话失败 {}:{}!", ssid, e.what()); + return false; + } + return true; + } + bool remove(const std::string &uid, const std::string &pid) { + //单聊会话的删除,-- 根据单聊会话的两个成员 + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + auto res = _db->query_one( + query::csm1::user_id == uid && + query::csm2::user_id == pid && + query::css::chat_session_type == ChatSessionType::SINGLE); + + std::string cssid = res->chat_session_id; + typedef odb::query cquery; + _db->erase_query(cquery::chat_session_id == cssid); + + typedef odb::query mquery; + _db->erase_query(mquery::session_id == cssid); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("删除会话失败 {}-{}:{}!", uid, pid, e.what()); + return false; + } + return true; + } + std::shared_ptr select(const std::string &ssid) { + std::shared_ptr res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + res.reset(_db->query_one(query::chat_session_id == ssid)); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("通过会话ID获取会话信息失败 {}:{}!", ssid, e.what()); + } + return res; + } + std::vector singleChatSession(const std::string &uid) { + std::vector res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + //当前的uid是被申请者的用户ID + result r(_db->query( + query::css::chat_session_type == ChatSessionType::SINGLE && + query::csm1::user_id == uid && + query::csm2::user_id != query::csm1::user_id)); + for (result::iterator i(r.begin()); i != r.end(); ++i) { + res.push_back(*i); + } + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("获取用户 {} 的单聊会话失败:{}!", uid, e.what()); + } + return res; + } + std::vector groupChatSession(const std::string &uid) { + std::vector res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + //当前的uid是被申请者的用户ID + result r(_db->query( + query::css::chat_session_type == ChatSessionType::GROUP && + query::csm::user_id == uid )); + for (result::iterator i(r.begin()); i != r.end(); ++i) { + res.push_back(*i); + } + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("获取用户 {} 的群聊会话失败:{}!", uid, e.what()); + } + return res; + } + private: + std::shared_ptr _db; + }; +} \ No newline at end of file diff --git a/common/mysql_chat_session_member.hpp b/common/mysql_chat_session_member.hpp new file mode 100644 index 0000000..3bc0796 --- /dev/null +++ b/common/mysql_chat_session_member.hpp @@ -0,0 +1,87 @@ +#pragma once +#include "mysql.hpp" +#include "chat_session_member.hxx" +#include "chat_session_member-odb.hxx" + +namespace bite_im { +class ChatSessionMemeberTable { + public: + using ptr = std::shared_ptr; + ChatSessionMemeberTable(const std::shared_ptr &db):_db(db){} + //单个会话成员的新增 --- ssid & uid + bool append(ChatSessionMember &csm) { + try { + odb::transaction trans(_db->begin()); + _db->persist(csm); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("新增单会话成员失败 {}-{}:{}!", + csm.session_id(), csm.user_id(), e.what()); + return false; + } + return true; + } + bool append(std::vector &csm_lists) { + try { + odb::transaction trans(_db->begin()); + for (auto &csm : csm_lists) { + _db->persist(csm); + } + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("新增多会话成员失败 {}-{}:{}!", + csm_lists[0].session_id(), csm_lists.size(), e.what()); + return false; + } + return true; + } + //删除指定会话中的指定成员 -- ssid & uid + bool remove(ChatSessionMember &csm) { + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + _db->erase_query(query::session_id == csm.session_id() && + query::user_id == csm.user_id()); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("删除单会话成员失败 {}-{}:{}!", + csm.session_id(), csm.user_id(), e.what()); + return false; + } + return true; + } + //删除会话的所有成员信息 + bool remove(const std::string &ssid) { + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + _db->erase_query(query::session_id == ssid); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("删除会话所有成员失败 {}:{}!", ssid, e.what()); + return false; + } + return true; + } + std::vector members(const std::string &ssid) { + std::vector res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + result r(_db->query(query::session_id == ssid)); + for (result::iterator i(r.begin()); i != r.end(); ++i) { + res.push_back(i->user_id()); + } + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("获取会话成员失败:{}-{}!", ssid, e.what()); + } + return res; + } + private: + std::shared_ptr _db; +}; +} \ No newline at end of file diff --git a/common/mysql_message.hpp b/common/mysql_message.hpp new file mode 100644 index 0000000..23b0480 --- /dev/null +++ b/common/mysql_message.hpp @@ -0,0 +1,86 @@ +#include "mysql.hpp" +#include "message.hxx" +#include "message-odb.hxx" + +namespace bite_im { +class MessageTable { + public: + using ptr = std::shared_ptr; + MessageTable(const std::shared_ptr &db): _db(db){} + ~MessageTable(){} + bool insert(Message &msg) { + try { + odb::transaction trans(_db->begin()); + _db->persist(msg); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("新增消息失败 {}:{}!", msg.message_id(),e.what()); + return false; + } + return true; + } + + bool remove(const std::string &ssid) { + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + _db->erase_query(query::session_id == ssid); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("删除会话所有消息失败 {}:{}!", ssid, e.what()); + return false; + } + return true; + } + + std::vector recent(const std::string &ssid, int count) { + std::vector res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + //本次查询是以ssid作为过滤条件,然后进行以时间字段进行逆序,通过limit + // session_id='xx' order by create_time desc limit count; + std::stringstream cond; + cond << "session_id='" << ssid << "' "; + cond << "order by create_time desc limit " << count; + result r(_db->query(cond.str())); + for (result::iterator i(r.begin()); i != r.end(); ++i) { + res.push_back(*i); + } + std::reverse(res.begin(), res.end()); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("获取最近消息失败:{}-{}-{}!", ssid, count, e.what()); + } + return res; + } + + std::vector range(const std::string &ssid, + boost::posix_time::ptime &stime, + boost::posix_time::ptime &etime) { + std::vector res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + //获取指定会话指定时间段的信息 + result r(_db->query(query::session_id == ssid && + query::create_time >= stime && + query::create_time <= etime)); + for (result::iterator i(r.begin()); i != r.end(); ++i) { + res.push_back(*i); + } + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("获取区间消息失败:{}-{}:{}-{}!", ssid, + boost::posix_time::to_simple_string(stime), + boost::posix_time::to_simple_string(etime), e.what()); + } + return res; + } + private: + std::shared_ptr _db; +}; +} \ No newline at end of file diff --git a/common/mysql_relation.hpp b/common/mysql_relation.hpp new file mode 100644 index 0000000..ff21388 --- /dev/null +++ b/common/mysql_relation.hpp @@ -0,0 +1,78 @@ +#pragma once +#include "mysql.hpp" +#include "relation.hxx" +#include "relation-odb.hxx" + +namespace bite_im { + class RelationTable { + public: + using ptr = std::shared_ptr; + RelationTable(const std::shared_ptr &db) : _db(db){} + //新增关系信息 + bool insert(const std::string &uid, const std::string &pid) { + //{1,2} {2,1} + try { + Relation r1(uid, pid); + Relation r2(pid, uid); + odb::transaction trans(_db->begin()); + _db->persist(r1); + _db->persist(r2); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("新增用户好友关系信息失败 {}-{}:{}!", uid, pid, e.what()); + return false; + } + return true; + } + //移除关系信息 + bool remove(const std::string &uid, const std::string &pid) { + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + _db->erase_query(query::user_id == uid && query::peer_id == pid); + _db->erase_query(query::user_id == pid && query::peer_id == uid); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("删除好友关系信息失败 {}-{}:{}!", uid, pid, e.what()); + return false; + } + return true; + } + //判断关系是否存在 + bool exists(const std::string &uid, const std::string &pid) { + typedef odb::query query; + typedef odb::result result; + result r; + bool flag = false; + try { + odb::transaction trans(_db->begin()); + r = _db->query(query::user_id == uid && query::peer_id == pid); + flag = !r.empty(); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("获取用户好友关系失败:{}-{}-{}!", uid, pid, e.what()); + } + return flag; + } + //获取指定用户的好友ID + std::vector friends(const std::string &uid) { + std::vector res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + result r(_db->query(query::user_id == uid)); + for (result::iterator i(r.begin()); i != r.end(); ++i) { + res.push_back(i->peer_id()); + } + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("通过用户-{}的所有好友ID失败:{}!", uid, e.what()); + } + return res; + } + private: + std::shared_ptr _db; + }; +} \ No newline at end of file diff --git a/common/mysql_user.hpp b/common/mysql_user.hpp new file mode 100644 index 0000000..d4bad74 --- /dev/null +++ b/common/mysql_user.hpp @@ -0,0 +1,102 @@ +#include "mysql.hpp" +#include "user.hxx" +#include "user-odb.hxx" + +namespace bite_im { +class UserTable { + public: + using ptr = std::shared_ptr; + UserTable(const std::shared_ptr &db):_db(db){} + bool insert(const std::shared_ptr &user) { + try { + odb::transaction trans(_db->begin()); + _db->persist(*user); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("新增用户失败 {}:{}!", user->nickname(),e.what()); + return false; + } + return true; + } + bool update(const std::shared_ptr &user) { + try { + odb::transaction trans(_db->begin()); + _db->update(*user); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("更新用户失败 {}:{}!", user->nickname(), e.what()); + return false; + } + return true; + } + std::shared_ptr select_by_nickname(const std::string &nickname) { + std::shared_ptr res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + res.reset(_db->query_one(query::nickname == nickname)); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("通过昵称查询用户失败 {}:{}!", nickname, e.what()); + } + return res; + } + std::shared_ptr select_by_phone(const std::string &phone) { + std::shared_ptr res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + res.reset(_db->query_one(query::phone == phone)); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("通过手机号查询用户失败 {}:{}!", phone, e.what()); + } + return res; + } + std::shared_ptr select_by_id(const std::string &user_id) { + std::shared_ptr res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + res.reset(_db->query_one(query::user_id == user_id)); + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("通过用户ID查询用户失败 {}:{}!", user_id, e.what()); + } + return res; + } + std::vector select_multi_users(const std::vector &id_list) { + // select * from user where id in ('id1', 'id2', ...) + if (id_list.empty()) { + return std::vector(); + } + std::vector res; + try { + odb::transaction trans(_db->begin()); + typedef odb::query query; + typedef odb::result result; + std::stringstream ss; + ss << "user_id in ("; + for (const auto &id : id_list) { + ss << "'" << id << "',"; + } + std::string condition = ss.str(); + condition.pop_back(); + condition += ")"; + result r(_db->query(condition)); + for (result::iterator i(r.begin()); i != r.end(); ++i) { + res.push_back(*i); + } + trans.commit(); + }catch (std::exception &e) { + LOG_ERROR("通过用户ID批量查询用户失败:{}!", e.what()); + } + return res; + } + private: + std::shared_ptr _db; +}; +} \ No newline at end of file diff --git a/common/rabbitmq.hpp b/common/rabbitmq.hpp new file mode 100644 index 0000000..c013535 --- /dev/null +++ b/common/rabbitmq.hpp @@ -0,0 +1,106 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "logger.hpp" + +namespace bite_im{ +class MQClient { + public: + using MessageCallback = std::function; + using ptr = std::shared_ptr; + MQClient(const std::string &user, + const std::string passwd, + const std::string host) { + _loop = EV_DEFAULT; + _handler = std::make_unique(_loop); + //amqp://root:123456@127.0.0.1:5672/ + std::string url = "amqp://" + user + ":" + passwd + "@" + host + "/"; + AMQP::Address address(url); + _connection = std::make_unique(_handler.get(), address); + _channel = std::make_unique(_connection.get()); + + _loop_thread = std::thread([this]() { + ev_run(_loop, 0); + }); + } + ~MQClient() { + ev_async_init(&_async_watcher, watcher_callback); + ev_async_start(_loop, &_async_watcher); + ev_async_send(_loop, &_async_watcher); + _loop_thread.join(); + _loop = nullptr; + } + void declareComponents(const std::string &exchange, + const std::string &queue, + const std::string &routing_key = "routing_key", + AMQP::ExchangeType echange_type = AMQP::ExchangeType::direct) { + _channel->declareExchange(exchange, echange_type) + .onError([](const char *message) { + LOG_ERROR("声明交换机失败:{}", message); + exit(0); + }) + .onSuccess([exchange](){ + LOG_ERROR("{} 交换机创建成功!", exchange); + }); + _channel->declareQueue(queue) + .onError([](const char *message) { + LOG_ERROR("声明队列失败:{}", message); + exit(0); + }) + .onSuccess([queue](){ + LOG_ERROR("{} 队列创建成功!", queue); + }); + //6. 针对交换机和队列进行绑定 + _channel->bindQueue(exchange, queue, routing_key) + .onError([exchange, queue](const char *message) { + LOG_ERROR("{} - {} 绑定失败:", exchange, queue); + exit(0); + }) + .onSuccess([exchange, queue, routing_key](){ + LOG_ERROR("{} - {} - {} 绑定成功!", exchange, queue, routing_key); + }); + } + + bool publish(const std::string &exchange, + const std::string &msg, + const std::string &routing_key = "routing_key") { + LOG_DEBUG("向交换机 {}-{} 发布消息!", exchange, routing_key); + bool ret = _channel->publish(exchange, routing_key, msg); + if (ret == false) { + LOG_ERROR("{} 发布消息失败:", exchange); + return false; + } + return true; + } + void consume(const std::string &queue, const MessageCallback &cb) { + LOG_DEBUG("开始订阅 {} 队列消息!", queue); + _channel->consume(queue, "consume-tag") //返回值 DeferredConsumer + .onReceived([this, cb](const AMQP::Message &message, + uint64_t deliveryTag, + bool redelivered) { + cb(message.body(), message.bodySize()); + _channel->ack(deliveryTag); + }) + .onError([queue](const char *message){ + LOG_ERROR("订阅 {} 队列消息失败: {}", queue, message); + exit(0); + }); + } + private: + static void watcher_callback(struct ev_loop *loop, ev_async *watcher, int32_t revents) { + ev_break(loop, EVBREAK_ALL); + } + private: + struct ev_async _async_watcher; + struct ev_loop *_loop; + std::unique_ptr _handler; + std::unique_ptr _connection; + std::unique_ptr _channel; + std::thread _loop_thread; +}; +} \ No newline at end of file diff --git a/common/sendemail.hpp b/common/sendemail.hpp new file mode 100644 index 0000000..ec56e28 --- /dev/null +++ b/common/sendemail.hpp @@ -0,0 +1,262 @@ +#pragma once + +#include +#include +#include +#include +#include +#include // for memset +#include +#include +#include +#include +#include +#include + +class SendEmail +{ +// 请确保在链接器设置中加入 libssl.lib 与 libcrypto.lib(或通过 vcpkg 自动集成) +public: + using ptr = std::shared_ptr; + + // 辅助函数:通过 SSL 发送数据,并检查返回值 + bool sendSSL(SSL* ssl, const std::string& msg, const char* label) + { + int bytesSent = SSL_write(ssl, msg.c_str(), static_cast(msg.length())); + if (bytesSent <= 0) { + std::cerr << "Failed to send " << label << " message via SSL." << std::endl; + ERR_print_errors_fp(stderr); + return false; + } + return true; + } + + std::string base64_encode(const std::string& input) + { + BIO* bmem = BIO_new(BIO_s_mem()); + BIO* b64 = BIO_new(BIO_f_base64()); + BIO_set_flags(b64, BIO_FLAGS_BASE64_NO_NL); + b64 = BIO_push(b64, bmem); + BIO_write(b64, input.c_str(), static_cast(input.length())); + BIO_flush(b64); + + BUF_MEM* bptr; + BIO_get_mem_ptr(b64, &bptr); + + std::string result(bptr->data, bptr->length); + BIO_free_all(b64); + return result; + } + + // 辅助函数:通过 SSL 接收数据,并打印输出 + bool recvSSL(SSL* ssl, char* buff, int buffSize, const char* label) + { + int bytesReceived = SSL_read(ssl, buff, buffSize); + if (bytesReceived <= 0) { + std::cerr << "Failed to receive " << label << " message via SSL." << std::endl; + ERR_print_errors_fp(stderr); + return false; + } + buff[bytesReceived] = '\0'; + std::cout << label << " response: " << buff << std::endl; + return true; + } + + /// + /// 发送邮件到指定邮箱 + /// + /// + /// 成功 = true,失败 = false + bool SEND_email(const std::string& email) + { + // 1. DNS解析 smtp.qq.com(使用端口 465) + struct addrinfo hints = {}; + hints.ai_family = AF_INET; // IPv4 + hints.ai_socktype = SOCK_STREAM; // TCP + + struct addrinfo* addrResult = nullptr; + int ret = getaddrinfo("smtp.qq.com", "465", &hints, &addrResult); + if (ret != 0 || addrResult == nullptr) { + std::cerr << "DNS resolution failed: " << gai_strerror(ret) << std::endl; + return false; + } + + // 2. 建立 TCP 连接 + int sock = socket(addrResult->ai_family, addrResult->ai_socktype, addrResult->ai_protocol); + if (sock == -1) { + std::cerr << "Socket creation failed!" << std::endl; + freeaddrinfo(addrResult); + return false; + } + + if (connect(sock, addrResult->ai_addr, addrResult->ai_addrlen) != 0) { + std::cerr << "Connection failed!" << std::endl; + close(sock); + freeaddrinfo(addrResult); + return false; + } + freeaddrinfo(addrResult); + + // 3. 初始化 OpenSSL + SSL_library_init(); + SSL_load_error_strings(); + OpenSSL_add_all_algorithms(); + + const SSL_METHOD* method = TLS_client_method(); + SSL_CTX* sslCtx = SSL_CTX_new(method); + if (!sslCtx) { + std::cerr << "Unable to create SSL context." << std::endl; + close(sock); + return false; + } + + SSL* ssl = SSL_new(sslCtx); + if (!ssl) { + std::cerr << "Unable to create SSL structure." << std::endl; + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 绑定 socket 到 SSL + SSL_set_fd(ssl, sock); + if (SSL_connect(ssl) <= 0) { + std::cerr << "SSL_connect failed." << std::endl; + ERR_print_errors_fp(stderr); + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 4. 读取服务器欢迎信息 + char buff[2048] = { 0 }; + if (!recvSSL(ssl, buff, sizeof(buff) - 1, "Server")) { + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 5. 发送 EHLO 命令 + std::string sendMsg = "EHLO localhost\r\n"; + if (!sendSSL(ssl, sendMsg, "EHLO") || !recvSSL(ssl, buff, sizeof(buff) - 1, "EHLO")) { + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 6. 认证流程:AUTH LOGIN + sendMsg = "AUTH LOGIN\r\n"; + if (!sendSSL(ssl, sendMsg, "AUTH") || !recvSSL(ssl, buff, sizeof(buff) - 1, "AUTH")) { + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 7. 发送用户名(Base64编码后的) + std::string username = base64_encode("zxiao_xin@qq.com") + "\r\n"; + if (!sendSSL(ssl, username, "Username") || !recvSSL(ssl, buff, sizeof(buff) - 1, "Username")) { + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 8. 发送密码(Base64编码后的) + std::string password = base64_encode("ydfslvabdryvejai") + "\r\n"; + if (!sendSSL(ssl, password, "Password") || !recvSSL(ssl, buff, sizeof(buff) - 1, "Password")) { + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 9. 设置发件人 + sendMsg = "MAIL FROM:\r\n"; + if (!sendSSL(ssl, sendMsg, "MAIL FROM") || !recvSSL(ssl, buff, sizeof(buff) - 1, "MAIL FROM")) { + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 10. 设置收件人 + sendMsg = "RCPT TO:<" + email + ">\r\n"; + if (!sendSSL(ssl, sendMsg, "RCPT TO") || !recvSSL(ssl, buff, sizeof(buff) - 1, "RCPT TO")) { + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 检查回复是否为成功代码 + std::string s = buff; + s = s.substr(0, 3); + if (s != "250") { + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 11. 命令 DATA + sendMsg = "DATA\r\n"; + if (!sendSSL(ssl, sendMsg, "DATA") || !recvSSL(ssl, buff, sizeof(buff) - 1, "DATA")) { + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 12. 生成验证码 + std::default_random_engine engine(static_cast(std::time(nullptr))); + std::uniform_int_distribution distribution(10000, 99999); + std::string verificationCode = std::to_string(distribution(engine)); + + _verifyCode = verificationCode; + + // 13. 构造邮件头和正文 + sendMsg = + "From: \"Mysterious系统\" \r\n" + "Reply-To: \"请勿回复\" \r\n" + "To: <" + email + ">\r\n" + "Subject: Mysterious验证码\r\n" + "Content-Type: text/plain; charset=UTF-8\r\n" + "\r\n" + "这是您的验证码: " + verificationCode + "\r\n" + "用于身份验证,请勿泄露。如非本人操作,请忽略此邮件。\r\n" + "\r\n.\r\n"; + + if (!sendSSL(ssl, sendMsg, "Email DATA") || !recvSSL(ssl, buff, sizeof(buff) - 1, "DATA send")) { + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + return false; + } + + // 14. 命令 QUIT + sendMsg = "QUIT\r\n"; + sendSSL(ssl, sendMsg, "QUIT"); + recvSSL(ssl, buff, sizeof(buff) - 1, "QUIT"); + + // 15. 清理资源 + SSL_shutdown(ssl); + SSL_free(ssl); + SSL_CTX_free(sslCtx); + close(sock); + + return true; + } + + std::string getVerifyCode() { + return _verifyCode; + } + +private: + // std::string _email; + std::string _verifyCode; +}; + diff --git a/common/utils.hpp b/common/utils.hpp new file mode 100644 index 0000000..e9a3e80 --- /dev/null +++ b/common/utils.hpp @@ -0,0 +1,86 @@ +//实现项目中一些公共的工具类接口 +//1. 生成一个唯一ID的接口 +//2. 文件的读写操作接口 + +#include +#include +#include +#include +#include +#include +#include +#include "logger.hpp" + +namespace bite_im { + +std::string uuid() { + //生成一个由16位随机字符组成的字符串作为唯一ID + // 1. 生成6个0~255之间的随机数字(1字节-转换为16进制字符)--生成12位16进制字符 + std::random_device rd;//实例化设备随机数对象-用于生成设备随机数 + std::mt19937 generator(rd());//以设备随机数为种子,实例化伪随机数对象 + std::uniform_int_distribution distribution(0,255); //限定数据范围 + + std::stringstream ss; + for (int i = 0; i < 6; i++) { + if (i == 2) ss << "-"; + ss << std::setw(2) << std::setfill('0') << std::hex << distribution(generator); + } + ss << "-"; + // 2. 通过一个静态变量生成一个2字节的编号数字--生成4位16进制数字字符 + static std::atomic idx(0); + short tmp = idx.fetch_add(1); + ss << std::setw(4) << std::setfill('0') << std::hex << tmp; + return ss.str(); +} + +std::string vcode() { + std::random_device rd;//实例化设备随机数对象-用于生成设备随机数 + std::mt19937 generator(rd());//以设备随机数为种子,实例化伪随机数对象 + std::uniform_int_distribution distribution(0,9); //限定数据范围 + + std::stringstream ss; + for (int i = 0; i < 4; i++) { + ss << distribution(generator); + } + return ss.str(); +} + +bool readFile(const std::string &filename, std::string &body){ + //实现读取一个文件的所有数据,放入body中 + std::ifstream ifs(filename, std::ios::binary | std::ios::in); + if (ifs.is_open() == false) { + LOG_ERROR("打开文件 {} 失败!", filename); + return false; + } + ifs.seekg(0, std::ios::end);//跳转到文件末尾 + size_t flen = ifs.tellg(); //获取当前偏移量-- 文件大小 + ifs.seekg(0, std::ios::beg);//跳转到文件起始 + body.resize(flen); + ifs.read(&body[0], flen); + if (ifs.good() == false) { + LOG_ERROR("读取文件 {} 数据失败!", filename); + ifs.close(); + return false; + } + ifs.close(); + return true; +} + +bool writeFile(const std::string &filename, const std::string &body){ + //实现将body中的数据,写入filename对应的文件中 + std::ofstream ofs(filename, std::ios::out | std::ios::binary | std::ios::trunc); + if (ofs.is_open() == false) { + LOG_ERROR("打开文件 {} 失败!", filename); + return false; + } + ofs.write(body.c_str(), body.size()); + if (ofs.good() == false) { + LOG_ERROR("读取文件 {} 数据失败!", filename); + ofs.close(); + return false; + } + ofs.close(); + return true; +} + +} \ No newline at end of file diff --git a/conf/file_server.conf b/conf/file_server.conf new file mode 100644 index 0000000..eab3d48 --- /dev/null +++ b/conf/file_server.conf @@ -0,0 +1,11 @@ +-run_mode=true +-log_file=/im/logs/file.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-base_service=/service +-instance_name=/file_service/instance +-access_host=10.0.0.235:10002 +-storage_path=/im/data/ +-listen_port=10002 +-rpc_timeout=-1 +-rpc_threads=1 \ No newline at end of file diff --git a/conf/friend_server.conf b/conf/friend_server.conf new file mode 100644 index 0000000..b911ffc --- /dev/null +++ b/conf/friend_server.conf @@ -0,0 +1,20 @@ +-run_mode=true +-log_file=/im/logs/friend.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-instance_name=/friend_service/instance +-access_host=10.0.0.235:10006 +-listen_port=10006 +-rpc_timeout=-1 +-rpc_threads=1 +-base_service=/service +-user_service=/service/user_service +-message_service=/service/message_service +-es_host=http://10.0.0.235:9200/ +-mysql_host=10.0.0.235 +-mysql_user=root +-mysql_pswd=123456 +-mysql_db=bite_im +-mysql_cset=utf8 +-mysql_port=0 +-mysql_pool_count=4 \ No newline at end of file diff --git a/conf/gateway_server.conf b/conf/gateway_server.conf new file mode 100644 index 0000000..8c98598 --- /dev/null +++ b/conf/gateway_server.conf @@ -0,0 +1,17 @@ +-run_mode=true +-log_file=/im/logs/gateway.log +-log_level=0 +-http_listen_port=9000 +-websocket_listen_port=9001 +-registry_host=http://10.0.0.235:2379 +-base_service=/service +-file_service=/service/file_service +-friend_service=/service/friend_service +-message_service=/service/message_service +-user_service=/service/user_service +-speech_service=/service/speech_service +-transmite_service=/service/transmite_service +-redis_host=10.0.0.235 +-redis_port=6379 +-redis_db=0 +-redis_keep_alive=true \ No newline at end of file diff --git a/conf/message_server.conf b/conf/message_server.conf new file mode 100644 index 0000000..d4f22cc --- /dev/null +++ b/conf/message_server.conf @@ -0,0 +1,26 @@ +-run_mode=true +-log_file=/im/logs/message.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-instance_name=/message_service/instance +-access_host=10.0.0.235:10005 +-listen_port=10005 +-rpc_timeout=-1 +-rpc_threads=1 +-base_service=/service +-user_service=/service/user_service +-file_service=/service/file_service +-es_host=http://10.0.0.235:9200/ +-mysql_host=10.0.0.235 +-mysql_user=root +-mysql_pswd=123456 +-mysql_db=bite_im +-mysql_cset=utf8 +-mysql_port=0 +-mysql_pool_count=4 +-mq_user=root +-mq_pswd=123456 +-mq_host=10.0.0.235:5672 +-mq_msg_exchange=msg_exchange +-mq_msg_queue=msg_queue +-mq_msg_binding_key=msg_queue \ No newline at end of file diff --git a/conf/speech_server.conf b/conf/speech_server.conf new file mode 100644 index 0000000..dd259c8 --- /dev/null +++ b/conf/speech_server.conf @@ -0,0 +1,13 @@ +-run_mode=true +-log_file=/im/logs/speech.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-instance_name=/speech_service/instance +-access_host=10.0.0.235:10001 +-base_service=/service +-listen_port=10001 +-rpc_timeout=-1 +-rpc_threads=1 +-app_id=60694095 +-api_key=PWn6zlsxym8VwpBW8Or4PPGe +-secret_key=Bl0mn74iyAkr3FzCo5TZV7lBq7NYoms9 \ No newline at end of file diff --git a/conf/transmite_server.conf b/conf/transmite_server.conf new file mode 100644 index 0000000..5fec82b --- /dev/null +++ b/conf/transmite_server.conf @@ -0,0 +1,24 @@ +-run_mode=true +-log_file=/im/logs/transmite.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-instance_name=/transmite_service/instance +-access_host=10.0.0.235:10004 +-listen_port=10004 +-rpc_timeout=-1 +-rpc_threads=1 +-base_service=/service +-user_service=/service/user_service +-mysql_host=10.0.0.235 +-mysql_user=root +-mysql_pswd=123456 +-mysql_db=bite_im +-mysql_cset=utf8 +-mysql_port=0 +-mysql_pool_count=4 +-mq_user=root +-mq_pswd=123456 +-mq_host=10.0.0.235:5672 +-mq_msg_exchange=msg_exchange +-mq_msg_queue=msg_queue +-mq_msg_binding_key=msg_queue \ No newline at end of file diff --git a/conf/user_server.conf b/conf/user_server.conf new file mode 100644 index 0000000..8d6816a --- /dev/null +++ b/conf/user_server.conf @@ -0,0 +1,25 @@ +-run_mode=true +-log_file=/im/logs/user.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-instance_name=/user_service/instance +-access_host=10.0.0.235:10003 +-listen_port=10003 +-rpc_timeout=-1 +-rpc_threads=1 +-base_service=/service +-file_service=/service/file_service +-es_host=http://10.0.0.235:9200/ +-mysql_host=10.0.0.235 +-mysql_user=root +-mysql_pswd=123456 +-mysql_db=bite_im +-mysql_cset=utf8 +-mysql_port=0 +-mysql_pool_count=4 +-redis_host=10.0.0.235 +-redis_port=6379 +-redis_db=0 +-redis_keep_alive=true +-dms_key_id=LTAI5t6NF7vt499UeqYX6LB9 +-dms_key_secret=5hx1qvpXHDKfQDk73aJs6j53Q8KcF2 \ No newline at end of file diff --git a/depends.sh b/depends.sh new file mode 100755 index 0000000..88e9c93 --- /dev/null +++ b/depends.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +#传递两个参数: +# 1. 可执行程序的路径名 +# 2. 目录名称 --- 将这个程序的依赖库拷贝到指定目录下 +declare depends +get_depends() { + depends=$(ldd $1 | awk '{if (match($3,"/")){print $3}}') + #mkdir $2 + cp -Lr $depends $2 +} + +get_depends ./gateway/build/gateway_server ./gateway/depends +get_depends ./file/build/file_server ./file/depends +get_depends ./friend/build/friend_server ./friend/depends +get_depends ./message/build/message_server ./message/depends +get_depends ./speech/build/speech_server ./speech/depends +get_depends ./transmite/build/transmite_server ./transmite/depends +get_depends ./user/build/user_server ./user/depends + +cp /bin/nc ./gateway/ +cp /bin/nc ./file/ +cp /bin/nc ./friend/ +cp /bin/nc ./message/ +cp /bin/nc ./speech/ +cp /bin/nc ./transmite/ +cp /bin/nc ./user/ +get_depends /bin/nc ./gateway/depends +get_depends /bin/nc ./file/depends +get_depends /bin/nc ./friend/depends +get_depends /bin/nc ./message/depends +get_depends /bin/nc ./speech/depends +get_depends /bin/nc ./user/depends +get_depends /bin/nc ./transmite/depends diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..bdcdf08 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,221 @@ +version: "3.8" + +services: + etcd: + image: quay.io/coreos/etcd:v3.5.0 + container_name: etcd-service + environment: + - ETCD_NAME=etcd-s1 + - ETCD_DATA_DIR=/var/lib/etcd + - ETCD_LISTEN_CLIENT_URLS=http://0.0.0.0:2379 + - ETCD_ADVERTISE_CLIENT_URLS=http://0.0.0.0:2379 + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + - ./middle/data/etcd:/var/lib/etcd:rw + ports: + - 2379:2379 + restart: always + mysql: + image: mysql:8.0.42 + container_name: mysql-service + environment: + MYSQL_ROOT_PASSWORD: 123456 + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + - ./sql:/docker-entrypoint-initdb.d/:rw + - ./middle/data/mysql:/var/lib/mysql:rw + ports: + - 3306:3306 + restart: always + redis: + image: redis:7.0.15 + container_name: redis-service + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + - ./middle/data/redis:/var/lib/redis:rw + ports: + - 6379:6379 + restart: always + elasticsearch: + image: elasticsearch:7.17.21 + container_name: elasticsearch-service + environment: + - "discovery.type=single-node" + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + - ./middle/data/elasticsearch:/data:rw + ports: + - 9200:9200 + - 9300:9300 + restart: always + rabbitmq: + image: rabbitmq:3.10.8 + container_name: rabbitmq-service + environment: + RABBITMQ_DEFAULT_USER: root + RABBITMQ_DEFAULT_PASS: 123456 + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + - ./middle/data/rabbitmq:/var/lib/rabbitmq:rw + ports: + - 5672:5672 + restart: always + + file_server: + build: ./file + #image: server-user_server + container_name: file_server-service + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + # 挂载的信息: entrypoint.sh文件 数据目录(im/logs, im/data), 配置文件 + - ./conf/file_server.conf:/im/conf/file_server.conf + - ./middle/data/logs:/im/logs:rw + - ./middle/data/data:/im/data:rw + - ./entrypoint.sh:/im/bin/entrypoint.sh + ports: + - 10002:10002 + restart: always + entrypoint: + # 跟dockerfile中的cmd比较类似,都是容器启动后的默认操作--替代dockerfile中的cmd + /im/bin/entrypoint.sh -h 10.0.0.235 -p 2379 -c "/im/bin/file_server -flagfile=/im/conf/file_server.conf" + depends_on: + - etcd + friend_server: + build: ./friend + #image: file-server:v1 + container_name: friend_server-service + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + # 挂载的信息: entrypoint.sh文件 数据目录(im/logs, im/data), 配置文件 + - ./conf/friend_server.conf:/im/conf/friend_server.conf + - ./middle/data/logs:/im/logs:rw + - ./middle/data/data:/im/data:rw + - ./entrypoint.sh:/im/bin/entrypoint.sh + ports: + - 10006:10006 + restart: always + depends_on: + - etcd + - mysql + - elasticsearch + entrypoint: + # 跟dockerfile中的cmd比较类似,都是容器启动后的默认操作--替代dockerfile中的cmd + /im/bin/entrypoint.sh -h 10.0.0.235 -p 2379,3306,9200 -c "/im/bin/friend_server -flagfile=/im/conf/friend_server.conf" + gateway_server: + build: ./gateway + #image: file-server:v1 + container_name: gateway_server-service + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + # 挂载的信息: entrypoint.sh文件 数据目录(im/logs, im/data), 配置文件 + - ./conf/gateway_server.conf:/im/conf/gateway_server.conf + - ./middle/data/logs:/im/logs:rw + - ./middle/data/data:/im/data:rw + - ./entrypoint.sh:/im/bin/entrypoint.sh + ports: + - 9000:9000 + - 9001:9001 + restart: always + depends_on: + - etcd + - redis + entrypoint: + # 跟dockerfile中的cmd比较类似,都是容器启动后的默认操作--替代dockerfile中的cmd + /im/bin/entrypoint.sh -h 10.0.0.235 -p 2379,6379 -c "/im/bin/gateway_server -flagfile=/im/conf/gateway_server.conf" + message_server: + build: ./message + #image: file-server:v1 + container_name: message_server-service + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + # 挂载的信息: entrypoint.sh文件 数据目录(im/logs, im/data), 配置文件 + - ./conf/message_server.conf:/im/conf/message_server.conf + - ./middle/data/logs:/im/logs:rw + - ./middle/data/data:/im/data:rw + - ./entrypoint.sh:/im/bin/entrypoint.sh + ports: + - 10005:10005 + restart: always + depends_on: + - etcd + - mysql + - elasticsearch + - rabbitmq + entrypoint: + # 跟dockerfile中的cmd比较类似,都是容器启动后的默认操作--替代dockerfile中的cmd + /im/bin/entrypoint.sh -h 10.0.0.235 -p 2379,3306,9200,5672 -c "/im/bin/message_server -flagfile=/im/conf/message_server.conf" + speech_server: + build: ./speech + #image: file-server:v1 + container_name: speech_server-service + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + # 挂载的信息: entrypoint.sh文件 数据目录(im/logs, im/data), 配置文件 + - ./conf/speech_server.conf:/im/conf/speech_server.conf + - ./middle/data/logs:/im/logs:rw + - ./middle/data/data:/im/data:rw + - ./entrypoint.sh:/im/bin/entrypoint.sh + ports: + - 10001:10001 + restart: always + depends_on: + - etcd + entrypoint: + # 跟dockerfile中的cmd比较类似,都是容器启动后的默认操作--替代dockerfile中的cmd + /im/bin/entrypoint.sh -h 10.0.0.235 -p 2379 -c "/im/bin/speech_server -flagfile=/im/conf/speech_server.conf" + transmite_server: + build: ./transmite + #image: file-server:v1 + container_name: transmite_server-service + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + # 挂载的信息: entrypoint.sh文件 数据目录(im/logs, im/data), 配置文件 + - ./conf/transmite_server.conf:/im/conf/transmite_server.conf + - ./middle/data/logs:/im/logs:rw + - ./middle/data/data:/im/data:rw + - ./entrypoint.sh:/im/bin/entrypoint.sh + ports: + - 10004:10004 + restart: always + depends_on: + - etcd + - mysql + - rabbitmq + entrypoint: + # 跟dockerfile中的cmd比较类似,都是容器启动后的默认操作--替代dockerfile中的cmd + /im/bin/entrypoint.sh -h 10.0.0.235 -p 2379,3306,5672 -c "/im/bin/transmite_server -flagfile=/im/conf/transmite_server.conf" + user_server: + build: ./user + #image: file-server:v1 + container_name: user_server-service + volumes: + # 1. 希望容器内的程序能够访问宿主机上的文件 + # 2. 希望容器内程序运行所产生的数据文件能落在宿主机上 + # 挂载的信息: entrypoint.sh文件 数据目录(im/logs, im/data), 配置文件 + - ./conf/user_server.conf:/im/conf/user_server.conf + - ./middle/data/logs:/im/logs:rw + - ./middle/data/data:/im/data:rw + - ./entrypoint.sh:/im/bin/entrypoint.sh + ports: + - 10003:10003 + restart: always + depends_on: + - etcd + - mysql + - redis + - elasticsearch + entrypoint: + # 跟dockerfile中的cmd比较类似,都是容器启动后的默认操作--替代dockerfile中的cmd + /im/bin/entrypoint.sh -h 10.0.0.235 -p 2379,3306,5672,9200 -c "/im/bin/user_server -flagfile=/im/conf/user_server.conf" + \ No newline at end of file diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 0000000..d8eb7d1 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +#./entrypoint.sh -h 127.0.0.1 -p 3306,2379,6379 -c '/im/bin/file_server -flagfile=./xx.conf' + +# 1. 编写一个端口探测函数,端口连接不上则循环等待 +# wait_for 127.0.0.1 3306 +wait_for() { + while ! nc -z $1 $2 + do + echo "$2 端口连接失败,休眠等待!"; + sleep 1; + done + echo "$1:$2 检测成功!"; +} +# 2. 对脚本运行参数进行解析,获取到ip,port,command +declare ip +declare ports +declare command +while getopts "h:p:c:" arg +do + case $arg in + h) + ip=$OPTARG;; + p) + ports=$OPTARG;; + c) + command=$OPTARG;; + esac +done +# 3. 通过执行脚本进行端口检测 +# ${port //,/ } 针对port中的内容,以空格替换字符串中的, shell中数组--一种以空格间隔的字符串 +for port in ${ports//,/ } +do + wait_for $ip $port +done +# 4. 执行command +# eval 对一个字符串进行二次检测,将其当作命令进行执行 +eval $command \ No newline at end of file diff --git a/file/CMakeLists.txt b/file/CMakeLists.txt new file mode 100644 index 0000000..c80348d --- /dev/null +++ b/file/CMakeLists.txt @@ -0,0 +1,55 @@ +# 1. 添加cmake版本说明 +cmake_minimum_required(VERSION 3.1.3) +# 2. 声明工程名称 +project(file_server) + +set(target "file_server") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的proto映射代码文件名称 +set(proto_path ${CMAKE_CURRENT_SOURCE_DIR}/../proto) +set(proto_files file.proto base.proto) +# 2. 检测框架代码文件是否已经生成 +set(proto_hxx "") +set(proto_cxx "") +set(proto_srcs "") +foreach(proto_file ${proto_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".proto" ".pb.cc" proto_cc ${proto_file}) + string(REPLACE ".proto" ".pb.h" proto_hh ${proto_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${proto_cc}) + add_custom_command( + PRE_BUILD + COMMAND protoc + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} -I ${proto_path} --experimental_allow_proto3_optional ${proto_path}/${proto_file} + DEPENDS ${proto_path}/${proto_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + COMMENT "生成Protobuf框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + ) + endif() + list(APPEND proto_srcs ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc}) +endforeach() + +# 4. 获取源码目录下的所有源码文件 +set(src_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/source src_files) +# 5. 声明目标及依赖 +add_executable(${target} ${src_files} ${proto_srcs}) +# 7. 设置需要连接的库 +target_link_libraries(${target} -lgflags -lspdlog -lfmt -lbrpc -lssl -lcrypto -lprotobuf -lleveldb -letcd-cpp-api -lcpprest -lcurl /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19) + + +set(test_client "file_client") +set(test_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/test test_files) +add_executable(${test_client} ${test_files} ${proto_srcs}) +target_link_libraries(${test_client} -lgtest -lgflags -lspdlog -lfmt -lbrpc -lssl -lcrypto -lprotobuf -lleveldb -letcd-cpp-api -lcpprest -lcurl /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19) + +# 6. 设置头文件默认搜索路径 +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../common) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third/include) + +#8. 设置安装路径 +INSTALL(TARGETS ${target} ${test_client} RUNTIME DESTINATION bin) \ No newline at end of file diff --git a/file/dockerfile b/file/dockerfile new file mode 100644 index 0000000..26badc4 --- /dev/null +++ b/file/dockerfile @@ -0,0 +1,16 @@ +# 声明基础经镜像来源 +FROM debian:12 + +# 声明工作目录 +WORKDIR /im +RUN mkdir -p /im/logs &&\ + mkdir -p /im/data &&\ + mkdir -p /im/conf &&\ + mkdir -p /im/bin + +# 将可执行程序依赖,拷贝进镜像 +COPY ./build/file_server /im/bin/ +# 将可执行程序文件,拷贝进镜像 +COPY ./depends /lib/x86_64-linux-gnu/ +# 设置容器启动的默认操作 ---运行程序 +CMD /im/bin/file_server -flagfile=/im/conf/file_server.conf \ No newline at end of file diff --git a/file/file_server.conf b/file/file_server.conf new file mode 100644 index 0000000..eab3d48 --- /dev/null +++ b/file/file_server.conf @@ -0,0 +1,11 @@ +-run_mode=true +-log_file=/im/logs/file.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-base_service=/service +-instance_name=/file_service/instance +-access_host=10.0.0.235:10002 +-storage_path=/im/data/ +-listen_port=10002 +-rpc_timeout=-1 +-rpc_threads=1 \ No newline at end of file diff --git a/file/source/file_server.cc b/file/source/file_server.cc new file mode 100644 index 0000000..3ea95ac --- /dev/null +++ b/file/source/file_server.cc @@ -0,0 +1,33 @@ +//按照流程完成服务器的搭建 +//1. 参数解析 +//2. 日志初始化 +//3. 构造服务器对象,启动服务器 +#include "file_server.hpp" + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(registry_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(instance_name, "/file_service/instance", "当前实例名称"); +DEFINE_string(access_host, "127.0.0.1:10002", "当前实例的外部访问地址"); + +DEFINE_string(storage_path, "./data/", "当前实例的外部访问地址"); + +DEFINE_int32(listen_port, 10002, "Rpc服务器监听端口"); +DEFINE_int32(rpc_timeout, -1, "Rpc调用超时时间"); +DEFINE_int32(rpc_threads, 1, "Rpc的IO线程数量"); + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + bite_im::FileServerBuilder fsb; + fsb.make_rpc_server(FLAGS_listen_port, FLAGS_rpc_timeout, FLAGS_rpc_threads, FLAGS_storage_path); + fsb.make_reg_object(FLAGS_registry_host, FLAGS_base_service + FLAGS_instance_name, FLAGS_access_host); + auto server = fsb.build(); + server->start(); + return 0; +} \ No newline at end of file diff --git a/file/source/file_server.hpp b/file/source/file_server.hpp new file mode 100644 index 0000000..6c7b6a5 --- /dev/null +++ b/file/source/file_server.hpp @@ -0,0 +1,190 @@ +//实现文件存储子服务 +//1. 实现文件rpc服务类 --- 实现rpc调用的业务处理接口 +//2. 实现文件存储子服务的服务器类 +//3. 实现文件存储子服务类的构造者 +#include +#include + +#include "asr.hpp" +#include "etcd.hpp" // 服务注册模块封装 +#include "logger.hpp" // 日志模块封装 +#include "utils.hpp" // uuid生成、文件读写等工具函数 +#include "base.pb.h" +#include "file.pb.h" + +namespace bite_im{ +class FileServiceImpl : public bite_im::FileService { + public: + FileServiceImpl(const std::string &storage_path): + _storage_path(storage_path){ + umask(0); + mkdir(storage_path.c_str(), 0775); + if (_storage_path.back() != '/') _storage_path.push_back('/'); + } + + ~FileServiceImpl(){} + + void GetSingleFile(google::protobuf::RpcController* controller, + const ::bite_im::GetSingleFileReq* request, + ::bite_im::GetSingleFileRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + response->set_request_id(request->request_id()); + //1. 取出请求中的文件ID(起始就是文件名) + std::string fid = request->file_id(); + std::string filename = _storage_path + fid; + //2. 将文件ID作为文件名,读取文件数据 + std::string body; + bool ret = readFile(filename, body); + if (ret == false) { + response->set_success(false); + response->set_errmsg("读取文件数据失败!"); + LOG_ERROR("{} 读取文件数据失败!", request->request_id()); + return; + } + //3. 组织响应 + response->set_success(true); + response->mutable_file_data()->set_file_id(fid); + response->mutable_file_data()->set_file_content(body); + } + + void GetMultiFile(google::protobuf::RpcController* controller, + const ::bite_im::GetMultiFileReq* request, + ::bite_im::GetMultiFileRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + response->set_request_id(request->request_id()); + // 循环取出请求中的文件ID,读取文件数据进行填充 + for (int i = 0; i < request->file_id_list_size(); i++) { + std::string fid = request->file_id_list(i); + std::string filename = _storage_path + fid; + std::string body; + bool ret = readFile(filename, body); + if (ret == false) { + response->set_success(false); + response->set_errmsg("读取文件数据失败!"); + LOG_ERROR("{} 读取文件数据失败!", request->request_id()); + return; + } + FileDownloadData data; + data.set_file_id(fid); + data.set_file_content(body); + response->mutable_file_data()->insert({fid, data}); + } + response->set_success(true); + } + + void PutSingleFile(google::protobuf::RpcController* controller, + const ::bite_im::PutSingleFileReq* request, + ::bite_im::PutSingleFileRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + response->set_request_id(request->request_id()); + //1. 为文件生成一个唯一uudi作为文件名 以及 文件ID + std::string fid = uuid(); + std::string filename = _storage_path + fid; + //2. 取出请求中的文件数据,进行文件数据写入 + bool ret = writeFile(filename, request->file_data().file_content()); + if (ret == false) { + response->set_success(false); + response->set_errmsg("读取文件数据失败!"); + LOG_ERROR("{} 写入文件数据失败!", request->request_id()); + return; + } + //3. 组织响应 + response->set_success(true); + response->mutable_file_info()->set_file_id(fid); + response->mutable_file_info()->set_file_size(request->file_data().file_size()); + response->mutable_file_info()->set_file_name(request->file_data().file_name()); + } + + void PutMultiFile(google::protobuf::RpcController* controller, + const ::bite_im::PutMultiFileReq* request, + ::bite_im::PutMultiFileRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + response->set_request_id(request->request_id()); + for (int i = 0; i < request->file_data_size(); i++) { + std::string fid = uuid(); + std::string filename = _storage_path + fid; + bool ret = writeFile(filename, request->file_data(i).file_content()); + if (ret == false) { + response->set_success(false); + response->set_errmsg("读取文件数据失败!"); + LOG_ERROR("{} 写入文件数据失败!", request->request_id()); + return; + } + bite_im::FileMessageInfo *info = response->add_file_info(); + info->set_file_id(fid); + info->set_file_size(request->file_data(i).file_size()); + info->set_file_name(request->file_data(i).file_name()); + } + response->set_success(true); + } + private: + std::string _storage_path; +}; + +class FileServer { + public: + using ptr = std::shared_ptr; + FileServer(const Registry::ptr ®_client, + const std::shared_ptr &server): + _reg_client(reg_client), + _rpc_server(server){} + ~FileServer(){} + //搭建RPC服务器,并启动服务器 + void start() { + _rpc_server->RunUntilAskedToQuit(); + } + private: + Registry::ptr _reg_client; + std::shared_ptr _rpc_server; +}; + +class FileServerBuilder { + public: + //用于构造服务注册客户端对象 + void make_reg_object(const std::string ®_host, + const std::string &service_name, + const std::string &access_host) { + _reg_client = std::make_shared(reg_host); + _reg_client->registry(service_name, access_host); + } + //构造RPC服务器对象 + void make_rpc_server(uint16_t port, int32_t timeout, + uint8_t num_threads, const std::string &path = "./data/") { + _rpc_server = std::make_shared(); + FileServiceImpl *file_service = new FileServiceImpl(path); + int ret = _rpc_server->AddService(file_service, + brpc::ServiceOwnership::SERVER_OWNS_SERVICE); + if (ret == -1) { + LOG_ERROR("添加Rpc服务失败!"); + abort(); + } + brpc::ServerOptions options; + options.idle_timeout_sec = timeout; + options.num_threads = num_threads; + ret = _rpc_server->Start(port, &options); + if (ret == -1) { + LOG_ERROR("服务启动失败!"); + abort(); + } + } + FileServer::ptr build() { + if (!_reg_client) { + LOG_ERROR("还未初始化服务注册模块!"); + abort(); + } + if (!_rpc_server) { + LOG_ERROR("还未初始化RPC服务器模块!"); + abort(); + } + FileServer::ptr server = std::make_shared(_reg_client, _rpc_server); + return server; + } + private: + Registry::ptr _reg_client; + std::shared_ptr _rpc_server; +}; +} \ No newline at end of file diff --git a/file/test/file_client.cc b/file/test/file_client.cc new file mode 100644 index 0000000..f6e90bf --- /dev/null +++ b/file/test/file_client.cc @@ -0,0 +1,152 @@ +//编写一个file客户端程序,对文件存储子服务进行单元测试 +// 1. 封装四个接口进行rpc调用,实现对于四个业务接口的测试 +#include +#include +#include +#include "etcd.hpp" +#include "channel.hpp" +#include "logger.hpp" +#include "file.pb.h" +#include "base.pb.h" +#include "utils.hpp" + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(etcd_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(file_service, "/service/file_service", "服务监控根目录"); + + +bite_im::ServiceChannel::ChannelPtr channel; +std::string single_file_id; + + +TEST(put_test, single_file) { + //1. 读取当前目录下的指定文件数据 + std::string body; + ASSERT_TRUE(bite_im::readFile("./Makefile", body)); + //2. 实例化rpc调用客户端对象,发起rpc调用 + bite_im::FileService_Stub stub(channel.get()); + + bite_im::PutSingleFileReq req; + req.set_request_id("1111"); + req.mutable_file_data()->set_file_name("Makefile"); + req.mutable_file_data()->set_file_size(body.size()); + req.mutable_file_data()->set_file_content(body); + + brpc::Controller *cntl = new brpc::Controller(); + bite_im::PutSingleFileRsp *rsp = new bite_im::PutSingleFileRsp(); + stub.PutSingleFile(cntl, &req, rsp, nullptr); + ASSERT_FALSE(cntl->Failed()); + //3. 检测返回值中上传是否成功 + ASSERT_TRUE(rsp->success()); + ASSERT_EQ(rsp->file_info().file_size(), body.size()); + ASSERT_EQ(rsp->file_info().file_name(), "Makefile"); + single_file_id = rsp->file_info().file_id(); + LOG_DEBUG("文件ID:{}", rsp->file_info().file_id()); +} + +TEST(get_test, single_file) { + //先发起Rpc调用,进行文件下载 + bite_im::FileService_Stub stub(channel.get()); + bite_im::GetSingleFileReq req; + bite_im::GetSingleFileRsp *rsp; + req.set_request_id("2222"); + req.set_file_id(single_file_id); + + brpc::Controller *cntl = new brpc::Controller(); + rsp = new bite_im::GetSingleFileRsp(); + stub.GetSingleFile(cntl, &req, rsp, nullptr); + ASSERT_FALSE(cntl->Failed()); + ASSERT_TRUE(rsp->success()); + //将文件数据,存储到文件中 + ASSERT_EQ(single_file_id, rsp->file_data().file_id()); + bite_im::writeFile("make_file_download", rsp->file_data().file_content()); +} + +std::vector multi_file_id; + +TEST(put_test, multi_file) { + //1. 读取当前目录下的指定文件数据 + std::string body1; + ASSERT_TRUE(bite_im::readFile("./base.pb.h", body1)); + std::string body2; + ASSERT_TRUE(bite_im::readFile("./file.pb.h", body2)); + //2. 实例化rpc调用客户端对象,发起rpc调用 + bite_im::FileService_Stub stub(channel.get()); + + bite_im::PutMultiFileReq req; + req.set_request_id("3333"); + + auto file_data = req.add_file_data(); + file_data->set_file_name("base.pb.h"); + file_data->set_file_size(body1.size()); + file_data->set_file_content(body1); + + file_data = req.add_file_data(); + file_data->set_file_name("file.pb.h"); + file_data->set_file_size(body2.size()); + file_data->set_file_content(body2); + + brpc::Controller *cntl = new brpc::Controller(); + bite_im::PutMultiFileRsp *rsp = new bite_im::PutMultiFileRsp(); + stub.PutMultiFile(cntl, &req, rsp, nullptr); + ASSERT_FALSE(cntl->Failed()); + //3. 检测返回值中上传是否成功 + ASSERT_TRUE(rsp->success()); + for (int i = 0; i < rsp->file_info_size(); i++){ + multi_file_id.push_back(rsp->file_info(i).file_id()); + LOG_DEBUG("文件ID:{}", multi_file_id[i]); + } +} + +TEST(get_test, multi_file) { + //先发起Rpc调用,进行文件下载 + bite_im::FileService_Stub stub(channel.get()); + bite_im::GetMultiFileReq req; + bite_im::GetMultiFileRsp *rsp; + req.set_request_id("4444"); + req.add_file_id_list(multi_file_id[0]); + req.add_file_id_list(multi_file_id[1]); + + brpc::Controller *cntl = new brpc::Controller(); + rsp = new bite_im::GetMultiFileRsp(); + stub.GetMultiFile(cntl, &req, rsp, nullptr); + ASSERT_FALSE(cntl->Failed()); + ASSERT_TRUE(rsp->success()); + //将文件数据,存储到文件中 + ASSERT_TRUE(rsp->file_data().find(multi_file_id[0]) != rsp->file_data().end()); + ASSERT_TRUE(rsp->file_data().find(multi_file_id[1]) != rsp->file_data().end()); + auto map = rsp->file_data(); + auto file_data1 = map[multi_file_id[0]]; + bite_im::writeFile("base_download_file1",file_data1.file_content()); + auto file_data2 = map[multi_file_id[1]]; + bite_im::writeFile("file_download_file2", file_data2.file_content()); +} + +int main(int argc, char *argv[]) +{ + testing::InitGoogleTest(&argc, argv); + google::ParseCommandLineFlags(&argc, &argv, true); + + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + //1. 先构造Rpc信道管理对象 + auto sm = std::make_shared(); + sm->declared(FLAGS_file_service); + auto put_cb = std::bind(&bite_im::ServiceManager::onServiceOnline, sm.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&bite_im::ServiceManager::onServiceOffline, sm.get(), std::placeholders::_1, std::placeholders::_2); + //2. 构造服务发现对象 + bite_im::Discovery::ptr dclient = std::make_shared(FLAGS_etcd_host, FLAGS_base_service, put_cb, del_cb); + + //3. 通过Rpc信道管理对象,获取提供Echo服务的信道 + channel = sm->choose(FLAGS_file_service); + if (!channel) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + return -1; + } + + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/friend/CMakeLists.txt b/friend/CMakeLists.txt new file mode 100644 index 0000000..533771b --- /dev/null +++ b/friend/CMakeLists.txt @@ -0,0 +1,88 @@ +# 1. 添加cmake版本说明 +cmake_minimum_required(VERSION 3.1.3) +# 2. 声明工程名称 +project(friend_server) + +set(target "friend_server") + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的proto映射代码文件名称 +set(proto_path ${CMAKE_CURRENT_SOURCE_DIR}/../proto) +set(proto_files base.proto user.proto message.proto friend.proto) +# 2. 检测框架代码文件是否已经生成 +set(proto_hxx "") +set(proto_cxx "") +set(proto_srcs "") +foreach(proto_file ${proto_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".proto" ".pb.cc" proto_cc ${proto_file}) + string(REPLACE ".proto" ".pb.h" proto_hh ${proto_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${proto_cc}) + add_custom_command( + PRE_BUILD + COMMAND protoc + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} -I ${proto_path} --experimental_allow_proto3_optional ${proto_path}/${proto_file} + DEPENDS ${proto_path}/${proto_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + COMMENT "生成Protobuf框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + ) + endif() + list(APPEND proto_srcs ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc}) +endforeach() + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的odb映射代码文件名称 +set(odb_path ${CMAKE_CURRENT_SOURCE_DIR}/../odb) +set(odb_files chat_session_member.hxx chat_session.hxx friend_apply.hxx relation.hxx) +# 2. 检测框架代码文件是否已经生成 +set(odb_hxx "") +set(odb_cxx "") +set(odb_srcs "") +foreach(odb_file ${odb_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".hxx" "-odb.hxx" odb_hxx ${odb_file}) + string(REPLACE ".hxx" "-odb.cxx" odb_cxx ${odb_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${odb_cxx}) + add_custom_command( + PRE_BUILD + COMMAND odb + ARGS -d mysql --std c++11 --generate-query --generate-schema --profile boost/date-time ${odb_path}/${odb_file} + DEPENDS ${odb_path}/${odb_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx} + COMMENT "生成ODB框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx} + ) + endif() +# 4. 将所有生成的框架源码文件名称保存起来 student-odb.cxx classes-odb.cxx + list(APPEND odb_srcs ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx}) +endforeach() + +# 4. 获取源码目录下的所有源码文件 +set(src_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/source src_files) +# 5. 声明目标及依赖 +add_executable(${target} ${src_files} ${proto_srcs} ${odb_srcs}) +# 7. 设置需要连接的库 +target_link_libraries(${target} -lgflags + -lspdlog -lfmt -lbrpc -lssl -lcrypto + -lprotobuf -lleveldb -letcd-cpp-api + -lcpprest -lcurl -lodb-mysql -lodb -lodb-boost + /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19 + -lcpr -lelasticlient) + + +set(test_client "friend_client") +set(test_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/test test_files) +add_executable(${test_client} ${test_files} ${proto_srcs}) +target_link_libraries(${test_client} -pthread -lgtest -lgflags -lspdlog -lfmt -lbrpc -lssl -lcrypto -lprotobuf -lleveldb -letcd-cpp-api -lcpprest -lcurl /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19) + +# 6. 设置头文件默认搜索路径 +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../common) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../odb) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third/include) + +#8. 设置安装路径 +INSTALL(TARGETS ${target} ${test_client} RUNTIME DESTINATION bin) \ No newline at end of file diff --git a/friend/dockerfile b/friend/dockerfile new file mode 100644 index 0000000..cfe0cec --- /dev/null +++ b/friend/dockerfile @@ -0,0 +1,16 @@ +# 声明基础经镜像来源 +FROM debian:12 + +# 声明工作目录 +WORKDIR /im +RUN mkdir -p /im/logs &&\ + mkdir -p /im/data &&\ + mkdir -p /im/conf &&\ + mkdir -p /im/bin + +# 将可执行程序依赖,拷贝进镜像 +COPY ./build/friend_server /im/bin/ +# 将可执行程序文件,拷贝进镜像 +COPY ./depends /lib/x86_64-linux-gnu/ +# 设置容器启动的默认操作 ---运行程序 +CMD /im/bin/friend_server -flagfile=/im/conf/friend_server.conf \ No newline at end of file diff --git a/friend/friend_server.conf b/friend/friend_server.conf new file mode 100644 index 0000000..b911ffc --- /dev/null +++ b/friend/friend_server.conf @@ -0,0 +1,20 @@ +-run_mode=true +-log_file=/im/logs/friend.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-instance_name=/friend_service/instance +-access_host=10.0.0.235:10006 +-listen_port=10006 +-rpc_timeout=-1 +-rpc_threads=1 +-base_service=/service +-user_service=/service/user_service +-message_service=/service/message_service +-es_host=http://10.0.0.235:9200/ +-mysql_host=10.0.0.235 +-mysql_user=root +-mysql_pswd=123456 +-mysql_db=bite_im +-mysql_cset=utf8 +-mysql_port=0 +-mysql_pool_count=4 \ No newline at end of file diff --git a/friend/source/friend_server.cc b/friend/source/friend_server.cc new file mode 100644 index 0000000..e8ff01b --- /dev/null +++ b/friend/source/friend_server.cc @@ -0,0 +1,49 @@ +//主要实现语音识别子服务的服务器的搭建 +#include "friend_server.hpp" + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(registry_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(instance_name, "/friend_service/instance", "当前实例名称"); +DEFINE_string(access_host, "127.0.0.1:10006", "当前实例的外部访问地址"); + +DEFINE_int32(listen_port, 10006, "Rpc服务器监听端口"); +DEFINE_int32(rpc_timeout, -1, "Rpc调用超时时间"); +DEFINE_int32(rpc_threads, 1, "Rpc的IO线程数量"); + + +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(user_service, "/service/user_service", "用户管理子服务名称"); +DEFINE_string(message_service, "/service/message_service", "消息存储子服务名称"); + +DEFINE_string(es_host, "http://127.0.0.1:9200/", "ES搜索引擎服务器URL"); + +DEFINE_string(mysql_host, "127.0.0.1", "Mysql服务器访问地址"); +DEFINE_string(mysql_user, "root", "Mysql服务器访问用户名"); +DEFINE_string(mysql_pswd, "123456", "Mysql服务器访问密码"); +DEFINE_string(mysql_db, "bite_im", "Mysql默认库名称"); +DEFINE_string(mysql_cset, "utf8", "Mysql客户端字符集"); +DEFINE_int32(mysql_port, 0, "Mysql服务器访问端口"); +DEFINE_int32(mysql_pool_count, 4, "Mysql连接池最大连接数量"); + + + + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + bite_im::FriendServerBuilder fsb; + fsb.make_es_object({FLAGS_es_host}); + fsb.make_mysql_object(FLAGS_mysql_user, FLAGS_mysql_pswd, FLAGS_mysql_host, + FLAGS_mysql_db, FLAGS_mysql_cset, FLAGS_mysql_port, FLAGS_mysql_pool_count); + fsb.make_discovery_object(FLAGS_registry_host, FLAGS_base_service, FLAGS_user_service, FLAGS_message_service); + fsb.make_rpc_server(FLAGS_listen_port, FLAGS_rpc_timeout, FLAGS_rpc_threads); + fsb.make_registry_object(FLAGS_registry_host, FLAGS_base_service + FLAGS_instance_name, FLAGS_access_host); + auto server = fsb.build(); + server->start(); + return 0; +} \ No newline at end of file diff --git a/friend/source/friend_server.hpp b/friend/source/friend_server.hpp new file mode 100644 index 0000000..ee32b97 --- /dev/null +++ b/friend/source/friend_server.hpp @@ -0,0 +1,637 @@ +//实现语音识别子服务 +#include +#include + +#include "data_es.hpp" // es数据管理客户端封装 +#include "mysql_chat_session_member.hpp" // mysql数据管理客户端封装 +#include "mysql_chat_session.hpp" // mysql数据管理客户端封装 +#include "mysql_relation.hpp" // mysql数据管理客户端封装 +#include "mysql_apply.hpp" // mysql数据管理客户端封装 +#include "etcd.hpp" // 服务注册模块封装 +#include "logger.hpp" // 日志模块封装 +#include "utils.hpp" // 基础工具接口 +#include "channel.hpp" // 信道管理模块封装 + + +#include "friend.pb.h" // protobuf框架代码 +#include "base.pb.h" // protobuf框架代码 +#include "user.pb.h" // protobuf框架代码 +#include "message.pb.h" // protobuf框架代码 + +namespace bite_im{ +class FriendServiceImpl : public bite_im::FriendService { + public: + FriendServiceImpl( + const std::shared_ptr &es_client, + const std::shared_ptr &mysql_client, + const ServiceManager::ptr &channel_manager, + const std::string &user_service_name, + const std::string &message_service_name) : + _es_user(std::make_shared(es_client)), + _mysql_apply(std::make_shared(mysql_client)), + _mysql_chat_session(std::make_shared(mysql_client)), + _mysql_chat_session_member(std::make_shared(mysql_client)), + _mysql_relation(std::make_shared(mysql_client)), + _user_service_name(user_service_name), + _message_service_name(message_service_name), + _mm_channels(channel_manager){} + ~FriendServiceImpl(){} + virtual void GetFriendList(::google::protobuf::RpcController* controller, + const ::bite_im::GetFriendListReq* request, + ::bite_im::GetFriendListRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + //1. 定义错误回调 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + + //1. 提取请求中的关键要素:用户ID + std::string rid = request->request_id(); + std::string uid = request->user_id(); + //2. 从数据库中查询获取用户的好友ID + auto friend_id_lists = _mysql_relation->friends(uid); + std::unordered_set user_id_lists; + for (auto &id : friend_id_lists) { + user_id_lists.insert(id); + } + //3. 从用户子服务批量获取用户信息 + std::unordered_map user_list; + bool ret = GetUserInfo(rid, user_id_lists, user_list); + if (ret == false) { + LOG_ERROR("{} - 批量获取用户信息失败!", rid); + return err_response(rid, "批量获取用户信息失败!"); + } + //4. 组织响应 + response->set_request_id(rid); + response->set_success(true); + for (const auto & user_it : user_list) { + auto user_info = response->add_friend_list(); + user_info->CopyFrom(user_it.second); + } + } + virtual void FriendRemove(::google::protobuf::RpcController* controller, + const ::bite_im::FriendRemoveReq* request, + ::bite_im::FriendRemoveRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + //1. 定义错误回调 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //1. 提取关键要素:当前用户ID,要删除的好友ID + std::string rid = request->request_id(); + std::string uid = request->user_id(); + std::string pid = request->peer_id(); + //2. 从好友关系表中删除好友关系信息 + bool ret = _mysql_relation->remove(uid, pid); + if (ret == false) { + LOG_ERROR("{} - 从数据库删除好友信息失败!", rid); + return err_response(rid, "从数据库删除好友信息失败!"); + } + //3. 从会话信息表中,删除对应的聊天会话 -- 同时删除会话成员表中的成员信息 + ret = _mysql_chat_session->remove(uid, pid); + if (ret == false) { + LOG_ERROR("{}- 从数据库删除好友会话信息失败!", rid); + return err_response(rid, "从数据库删除好友会话信息失败!"); + } + //4. 组织响应 + response->set_request_id(rid); + response->set_success(true); + } + virtual void FriendAdd(::google::protobuf::RpcController* controller, + const ::bite_im::FriendAddReq* request, + ::bite_im::FriendAddRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + //1. 定义错误回调 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //1. 提取请求中的关键要素:申请人用户ID; 被申请人用户ID + std::string rid = request->request_id(); + std::string uid = request->user_id(); + std::string pid = request->respondent_id(); + //2. 判断两人是否已经是好友 + bool ret = _mysql_relation->exists(uid, pid); + if (ret == true) { + LOG_ERROR("{}- 申请好友失败-两者{}-{}已经是好友关系", rid, uid, pid); + return err_response(rid, "两者已经是好友关系!"); + } + //3. 当前是否已经申请过好友 + ret = _mysql_apply->exists(uid, pid); + if (ret == true) { + LOG_ERROR("{}- 申请好友失败-已经申请过对方好友!", rid, uid, pid); + return err_response(rid, "已经申请过对方好友!"); + } + //4. 向好友申请表中,新增申请信息 + std::string eid = uuid(); + FriendApply ev(eid, uid, pid); + ret = _mysql_apply->insert(ev); + if (ret == false) { + LOG_ERROR("{} - 向数据库新增好友申请事件失败!", rid); + return err_response(rid, "向数据库新增好友申请事件失败!"); + } + //3. 组织响应 + response->set_request_id(rid); + response->set_success(true); + response->set_notify_event_id(eid); + } + virtual void FriendAddProcess(::google::protobuf::RpcController* controller, + const ::bite_im::FriendAddProcessReq* request, + ::bite_im::FriendAddProcessRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + //1. 定义错误回调 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //1. 提取请求中的关键要素:申请人用户ID;被申请人用户ID;处理结果;事件ID + std::string rid = request->request_id(); + std::string eid = request->notify_event_id(); + std::string uid = request->user_id(); //被申请人 + std::string pid = request->apply_user_id();//申请人 + bool agree = request->agree(); + //2. 判断有没有该申请事件 + bool ret = _mysql_apply->exists(pid, uid); + if (ret == false) { + LOG_ERROR("{}- 没有找到{}-{}对应的好友申请事件!", rid, pid, uid); + return err_response(rid, "没有找到对应的好友申请事件!"); + } + //3. 如果有: 可以处理; --- 删除申请事件--事件已经处理完毕 + ret = _mysql_apply->remove(pid, uid); + if (ret == false) { + LOG_ERROR("{}- 从数据库删除申请事件 {}-{} 失败!", rid, pid, uid); + return err_response(rid, "从数据库删除申请事件失败!"); + } + //4. 如果处理结果是同意:向数据库新增好友关系信息;新增单聊会话信息及会话成员 + std::string cssid; + if (agree == true) { + ret = _mysql_relation->insert(uid, pid); + if (ret == false) { + LOG_ERROR("{}- 新增好友关系信息-{}-{}!", rid, uid, pid); + return err_response(rid, "新增好友关系信息!"); + } + cssid = uuid(); + ChatSession cs(cssid, "", ChatSessionType::SINGLE); + ret = _mysql_chat_session->insert(cs); + if (ret == false) { + LOG_ERROR("{}- 新增单聊会话信息-{}!", rid, cssid); + return err_response(rid, "新增单聊会话信息失败!"); + } + ChatSessionMember csm1(cssid, uid); + ChatSessionMember csm2(cssid, pid); + std::vector mlist = {csm1, csm2}; + ret = _mysql_chat_session_member->append(mlist); + if (ret == false) { + LOG_ERROR("{}- 没有找到{}-{}对应的好友申请事件!", rid, pid, uid); + return err_response(rid, "没有找到对应的好友申请事件!"); + } + } + //5. 组织响应 + response->set_request_id(rid); + response->set_success(true); + response->set_new_session_id(cssid); + } + virtual void FriendSearch(::google::protobuf::RpcController* controller, + const ::bite_im::FriendSearchReq* request, + ::bite_im::FriendSearchRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + //1. 定义错误回调 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //1. 提取请求中的关键要素:搜索关键字(可能是用户ID,可能是手机号,可能是昵称的一部分) + std::string rid = request->request_id(); + std::string uid = request->user_id(); + std::string skey = request->search_key(); + LOG_DEBUG("{} 好友搜索 : {}", uid, skey); + //2. 根据用户ID,获取用户的好友ID列表 + auto friend_id_lists = _mysql_relation->friends(uid); + //3. 从ES搜索引擎进行用户信息搜索 --- 过滤掉当前的好友 + std::unordered_set user_id_lists; + friend_id_lists.push_back(uid);// 把自己也过滤掉 + auto search_res = _es_user->search(skey, friend_id_lists); + for (auto &it : search_res) { + user_id_lists.insert(it.user_id()); + } + //4. 根据获取到的用户ID, 从用户子服务器进行批量用户信息获取 + std::unordered_map user_list; + bool ret = GetUserInfo(rid, user_id_lists, user_list); + if (ret == false) { + LOG_ERROR("{} - 批量获取用户信息失败!", rid); + return err_response(rid, "批量获取用户信息失败!"); + } + //5. 组织响应 + response->set_request_id(rid); + response->set_success(true); + for (const auto & user_it : user_list) { + auto user_info = response->add_user_info(); + user_info->CopyFrom(user_it.second); + } + } + virtual void GetPendingFriendEventList(::google::protobuf::RpcController* controller, + const ::bite_im::GetPendingFriendEventListReq* request, + ::bite_im::GetPendingFriendEventListRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + //1. 定义错误回调 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //1. 提取关键要素:当前用户ID + std::string rid = request->request_id(); + std::string uid = request->user_id(); + //2. 从数据库获取待处理的申请事件信息 --- 申请人用户ID列表 + auto res = _mysql_apply->applyUsers(uid); + std::unordered_set user_id_lists; + for (auto &id : res) { + user_id_lists.insert(id); + } + //3. 批量获取申请人用户信息、 + std::unordered_map user_list; + bool ret = GetUserInfo(rid, user_id_lists, user_list); + if (ret == false) { + LOG_ERROR("{} - 批量获取用户信息失败!", rid); + return err_response(rid, "批量获取用户信息失败!"); + } + //4. 组织响应 + response->set_request_id(rid); + response->set_success(true); + for (const auto & user_it : user_list) { + auto ev = response->add_event(); + ev->mutable_sender()->CopyFrom(user_it.second); + } + } + virtual void GetChatSessionList(::google::protobuf::RpcController* controller, + const ::bite_im::GetChatSessionListReq* request, + ::bite_im::GetChatSessionListRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + //1. 定义错误回调 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //获取聊天会话的作用:一个用户登录成功后,能够展示自己的历史聊天信息 + //1. 提取请求中的关键要素:当前请求用户ID + std::string rid = request->request_id(); + std::string uid = request->user_id(); + //2. 从数据库中查询出用户的单聊会话列表 + auto sf_list = _mysql_chat_session->singleChatSession(uid); + // 1. 从单聊会话列表中,取出所有的好友ID,从用户子服务获取用户信息 + std::unordered_set users_id_list; + for (const auto &f : sf_list) { + users_id_list.insert(f.friend_id); + } + std::unordered_map user_list; + bool ret = GetUserInfo(rid, users_id_list, user_list); + if (ret == false) { + LOG_ERROR("{} - 批量获取用户信息失败!", rid); + return err_response(rid, "批量获取用户信息失败!"); + } + // 2. 设置响应会话信息:会话名称就是好友名称;会话头像就是好友头像 + //3. 从数据库中查询出用户的群聊会话列表 + auto gc_list = _mysql_chat_session->groupChatSession(uid); + + //4. 根据所有的会话ID,从消息存储子服务获取会话最后一条消息 + //5. 组织响应 + for (const auto &f : sf_list) { + auto chat_session_info = response->add_chat_session_info_list(); + chat_session_info->set_single_chat_friend_id(f.friend_id); + chat_session_info->set_chat_session_id(f.chat_session_id); + chat_session_info->set_chat_session_name(user_list[f.friend_id].nickname()); + chat_session_info->set_avatar(user_list[f.friend_id].avatar()); + MessageInfo msg; + ret = GetRecentMsg(rid, f.chat_session_id, msg); + if (ret == false) {continue;} + chat_session_info->mutable_prev_message()->CopyFrom(msg); + } + for (const auto &f : gc_list) { + auto chat_session_info = response->add_chat_session_info_list(); + chat_session_info->set_chat_session_id(f.chat_session_id); + chat_session_info->set_chat_session_name(f.chat_session_name); + MessageInfo msg; + ret = GetRecentMsg(rid, f.chat_session_id, msg); + if (ret == false) { continue; } + chat_session_info->mutable_prev_message()->CopyFrom(msg); + } + response->set_request_id(rid); + response->set_success(true); + } + virtual void ChatSessionCreate(::google::protobuf::RpcController* controller, + const ::bite_im::ChatSessionCreateReq* request, + ::bite_im::ChatSessionCreateRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + //1. 定义错误回调 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //创建会话,其实针对的是用户要创建一个群聊会话 + //1. 提取请求关键要素:会话名称,会话成员 + std::string rid = request->request_id(); + std::string uid = request->user_id(); + std::string cssname = request->chat_session_name(); + + //2. 生成会话ID,向数据库添加会话信息,添加会话成员信息 + std::string cssid = uuid(); + ChatSession cs(cssid, cssname, ChatSessionType::GROUP); + bool ret = _mysql_chat_session->insert(cs); + if (ret == false) { + LOG_ERROR("{} - 向数据库添加会话信息失败: {}", rid, cssname); + return err_response(rid, "向数据库添加会话信息失败!"); + } + std::vector member_list; + for (int i = 0; i < request->member_id_list_size(); i++) { + ChatSessionMember csm(cssid, request->member_id_list(i)); + member_list.push_back(csm); + } + ret = _mysql_chat_session_member->append(member_list); + if (ret == false) { + LOG_ERROR("{} - 向数据库添加会话成员信息失败: {}", rid, cssname); + return err_response(rid, "向数据库添加会话成员信息失败!"); + } + //3. 组织响应---组织会话信息 + response->set_request_id(rid); + response->set_success(true); + response->mutable_chat_session_info()->set_chat_session_id(cssid); + response->mutable_chat_session_info()->set_chat_session_name(cssname); + } + virtual void GetChatSessionMember(::google::protobuf::RpcController* controller, + const ::bite_im::GetChatSessionMemberReq* request, + ::bite_im::GetChatSessionMemberRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + //1. 定义错误回调 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //用于用户查看群聊成员信息的时候:进行成员信息展示 + //1. 提取关键要素:聊天会话ID + std::string rid = request->request_id(); + std::string uid = request->user_id(); + std::string cssid = request->chat_session_id(); + //2. 从数据库获取会话成员ID列表 + auto member_id_lists = _mysql_chat_session_member->members(cssid); + std::unordered_set uid_list; + for (const auto &id : member_id_lists) { + uid_list.insert(id); + } + //3. 从用户子服务批量获取用户信息 + std::unordered_map user_list; + bool ret = GetUserInfo(rid, uid_list, user_list); + if (ret == false) { + LOG_ERROR("{} - 从用户子服务获取用户信息失败!", rid); + return err_response(rid, "从用户子服务获取用户信息失败!"); + } + //4. 组织响应 + response->set_request_id(rid); + response->set_success(true); + for (const auto &uit : user_list) { + auto user_info = response->add_member_info_list(); + user_info->CopyFrom(uit.second); + } + } + private: + bool GetRecentMsg(const std::string &rid, + const std::string &cssid, MessageInfo &msg) { + auto channel = _mm_channels->choose(_message_service_name); + if (!channel) { + LOG_ERROR("{} - 获取消息子服务信道失败!!", rid); + return false; + } + GetRecentMsgReq req; + GetRecentMsgRsp rsp; + req.set_request_id(rid); + req.set_chat_session_id(cssid); + req.set_msg_count(1); + brpc::Controller cntl; + bite_im::MsgStorageService_Stub stub(channel.get()); + stub.GetRecentMsg(&cntl, &req, &rsp, nullptr); + if (cntl.Failed() == true) { + LOG_ERROR("{} - 消息存储子服务调用失败: {}", rid, cntl.ErrorText()); + return false; + } + if ( rsp.success() == false) { + LOG_ERROR("{} - 获取会话 {} 最近消息失败: {}", rid, cssid, rsp.errmsg()); + return false; + } + if (rsp.msg_list_size() > 0) { + msg.CopyFrom(rsp.msg_list(0)); + return true; + } + return false; + } + bool GetUserInfo(const std::string &rid, + const std::unordered_set &uid_list, + std::unordered_map &user_list) { + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} - 获取用户子服务信道失败!!", rid); + return false; + } + GetMultiUserInfoReq req; + GetMultiUserInfoRsp rsp; + req.set_request_id(rid); + for (auto &id : uid_list) { + req.add_users_id(id); + } + brpc::Controller cntl; + bite_im::UserService_Stub stub(channel.get()); + stub.GetMultiUserInfo(&cntl, &req, &rsp, nullptr); + if (cntl.Failed() == true) { + LOG_ERROR("{} - 用户子服务调用失败: {}", rid, cntl.ErrorText()); + return false; + } + if ( rsp.success() == false) { + LOG_ERROR("{} - 批量获取用户信息失败: {}", rid, rsp.errmsg()); + return false; + } + for (const auto & user_it : rsp.users_info()) { + user_list.insert(std::make_pair(user_it.first, user_it.second)); + } + return true; + } + + private: + ESUser::ptr _es_user; + + FriendApplyTable::ptr _mysql_apply; + ChatSessionTable::ptr _mysql_chat_session; + ChatSessionMemeberTable::ptr _mysql_chat_session_member; + RelationTable::ptr _mysql_relation; + + //这边是rpc调用客户端相关对象 + std::string _user_service_name; + std::string _message_service_name; + ServiceManager::ptr _mm_channels; +}; + +class FriendServer { + public: + using ptr = std::shared_ptr; + FriendServer(const Discovery::ptr service_discoverer, + const Registry::ptr ®_client, + const std::shared_ptr &es_client, + const std::shared_ptr &mysql_client, + const std::shared_ptr &server): + _service_discoverer(service_discoverer), + _registry_client(reg_client), + _es_client(es_client), + _mysql_client(mysql_client), + _rpc_server(server){} + ~FriendServer(){} + //搭建RPC服务器,并启动服务器 + void start() { + _rpc_server->RunUntilAskedToQuit(); + } + private: + Discovery::ptr _service_discoverer; + Registry::ptr _registry_client; + std::shared_ptr _es_client; + std::shared_ptr _mysql_client; + std::shared_ptr _rpc_server; +}; + +class FriendServerBuilder { + public: + //构造es客户端对象 + void make_es_object(const std::vector host_list) { + _es_client = ESClientFactory::create(host_list); + } + //构造mysql客户端对象 + void make_mysql_object( + const std::string &user, + const std::string &pswd, + const std::string &host, + const std::string &db, + const std::string &cset, + int port, + int conn_pool_count) { + _mysql_client = ODBFactory::create(user, pswd, host, db, cset, port, conn_pool_count); + } + //用于构造服务发现客户端&信道管理对象 + void make_discovery_object(const std::string ®_host, + const std::string &base_service_name, + const std::string &user_service_name, + const std::string &message_service_name) { + _user_service_name = user_service_name; + _message_service_name = message_service_name; + _mm_channels = std::make_shared(); + _mm_channels->declared(user_service_name); + _mm_channels->declared(message_service_name); + LOG_DEBUG("设置用户子服务为需添加管理的子服务:{}", user_service_name); + LOG_DEBUG("设置消息子服务为需添加管理的子服务:{}", message_service_name); + auto put_cb = std::bind(&ServiceManager::onServiceOnline, _mm_channels.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&ServiceManager::onServiceOffline, _mm_channels.get(), std::placeholders::_1, std::placeholders::_2); + _service_discoverer = std::make_shared(reg_host, base_service_name, put_cb, del_cb); + } + //用于构造服务注册客户端对象 + void make_registry_object(const std::string ®_host, + const std::string &service_name, + const std::string &access_host) { + _registry_client = std::make_shared(reg_host); + _registry_client->registry(service_name, access_host); + } + void make_rpc_server(uint16_t port, int32_t timeout, uint8_t num_threads) { + if (!_es_client) { + LOG_ERROR("还未初始化ES搜索引擎模块!"); + abort(); + } + if (!_mysql_client) { + LOG_ERROR("还未初始化Mysql数据库模块!"); + abort(); + } + if (!_mm_channels) { + LOG_ERROR("还未初始化信道管理模块!"); + abort(); + } + _rpc_server = std::make_shared(); + + FriendServiceImpl *friend_service = new FriendServiceImpl(_es_client, + _mysql_client, _mm_channels, _user_service_name, _message_service_name); + int ret = _rpc_server->AddService(friend_service, + brpc::ServiceOwnership::SERVER_OWNS_SERVICE); + if (ret == -1) { + LOG_ERROR("添加Rpc服务失败!"); + abort(); + } + brpc::ServerOptions options; + options.idle_timeout_sec = timeout; + options.num_threads = num_threads; + ret = _rpc_server->Start(port, &options); + if (ret == -1) { + LOG_ERROR("服务启动失败!"); + abort(); + } + } + //构造RPC服务器对象 + FriendServer::ptr build() { + if (!_service_discoverer) { + LOG_ERROR("还未初始化服务发现模块!"); + abort(); + } + if (!_registry_client) { + LOG_ERROR("还未初始化服务注册模块!"); + abort(); + } + if (!_rpc_server) { + LOG_ERROR("还未初始化RPC服务器模块!"); + abort(); + } + FriendServer::ptr server = std::make_shared( + _service_discoverer, _registry_client, + _es_client, _mysql_client, _rpc_server); + return server; + } + private: + Registry::ptr _registry_client; + + std::shared_ptr _es_client; + std::shared_ptr _mysql_client; + + std::string _user_service_name; + std::string _message_service_name; + ServiceManager::ptr _mm_channels; + Discovery::ptr _service_discoverer; + + std::shared_ptr _rpc_server; +}; +} \ No newline at end of file diff --git a/friend/test/friend_client.cc b/friend/test/friend_client.cc new file mode 100644 index 0000000..8476204 --- /dev/null +++ b/friend/test/friend_client.cc @@ -0,0 +1,286 @@ +#include "etcd.hpp" +#include "channel.hpp" +#include "utils.hpp" +#include +#include +#include +#include "friend.pb.h" + + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(etcd_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(friend_service, "/service/friend_service", "服务监控根目录"); + +bite_im::ServiceManager::ptr sm; + +void apply_test(const std::string &uid1, const std::string &uid2) { + auto channel = sm->choose(FLAGS_friend_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::FriendService_Stub stub(channel.get()); + bite_im::FriendAddReq req; + bite_im::FriendAddRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid1); + req.set_respondent_id(uid2); + brpc::Controller cntl; + stub.FriendAdd(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); +} + +void get_apply_list(const std::string &uid1) { + auto channel = sm->choose(FLAGS_friend_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::FriendService_Stub stub(channel.get()); + bite_im::GetPendingFriendEventListReq req; + bite_im::GetPendingFriendEventListRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid1); + brpc::Controller cntl; + stub.GetPendingFriendEventList(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + for (int i = 0; i < rsp.event_size(); i++) { + std::cout << "---------------\n"; + std::cout << rsp.event(i).sender().user_id() << std::endl; + std::cout << rsp.event(i).sender().nickname() << std::endl; + std::cout << rsp.event(i).sender().avatar() << std::endl; + } +} + +void process_apply_test(const std::string &uid1, bool agree, const std::string &apply_user_id) { + auto channel = sm->choose(FLAGS_friend_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::FriendService_Stub stub(channel.get()); + bite_im::FriendAddProcessReq req; + bite_im::FriendAddProcessRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid1); + req.set_agree(agree); + req.set_apply_user_id(apply_user_id); + brpc::Controller cntl; + stub.FriendAddProcess(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + if (agree) { + std::cout << rsp.new_session_id() << std::endl; + } +} + +void search_test(const std::string &uid1, const std::string &key) { + auto channel = sm->choose(FLAGS_friend_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::FriendService_Stub stub(channel.get()); + bite_im::FriendSearchReq req; + bite_im::FriendSearchRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid1); + req.set_search_key(key); + brpc::Controller cntl; + stub.FriendSearch(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + for (int i = 0; i < rsp.user_info_size(); i++) { + std::cout << "-------------------\n"; + std::cout << rsp.user_info(i).user_id() << std::endl; + std::cout << rsp.user_info(i).nickname() << std::endl; + std::cout << rsp.user_info(i).avatar() << std::endl; + } +} + +void friend_list_test(const std::string &uid1) { + auto channel = sm->choose(FLAGS_friend_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::FriendService_Stub stub(channel.get()); + bite_im::GetFriendListReq req; + bite_im::GetFriendListRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid1); + brpc::Controller cntl; + stub.GetFriendList(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + for (int i = 0; i < rsp.friend_list_size(); i++) { + std::cout << "-------------------\n"; + std::cout << rsp.friend_list(i).user_id() << std::endl; + std::cout << rsp.friend_list(i).nickname() << std::endl; + std::cout << rsp.friend_list(i).avatar() << std::endl; + } +} + + +void remove_test(const std::string &uid1, const std::string &uid2) { + auto channel = sm->choose(FLAGS_friend_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::FriendService_Stub stub(channel.get()); + bite_im::FriendRemoveReq req; + bite_im::FriendRemoveRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid1); + req.set_peer_id(uid2); + brpc::Controller cntl; + stub.FriendRemove(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); +} +void create_css_test(const std::string &uid1, const std::vector &uidlist) { + auto channel = sm->choose(FLAGS_friend_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::FriendService_Stub stub(channel.get()); + bite_im::ChatSessionCreateReq req; + bite_im::ChatSessionCreateRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid1); + req.set_chat_session_name("快乐一家人"); + for (auto &id : uidlist) { + req.add_member_id_list(id); + } + brpc::Controller cntl; + stub.ChatSessionCreate(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + std::cout << rsp.chat_session_info().chat_session_id() << std::endl; + std::cout << rsp.chat_session_info().chat_session_name() << std::endl; +} + + +void cssmember_test(const std::string &uid1, const std::string &cssid) { + auto channel = sm->choose(FLAGS_friend_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::FriendService_Stub stub(channel.get()); + bite_im::GetChatSessionMemberReq req; + bite_im::GetChatSessionMemberRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid1); + req.set_chat_session_id(cssid); + brpc::Controller cntl; + stub.GetChatSessionMember(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + for (int i = 0; i < rsp.member_info_list_size(); i++) { + std::cout << "-------------------\n"; + std::cout << rsp.member_info_list(i).user_id() << std::endl; + std::cout << rsp.member_info_list(i).nickname() << std::endl; + std::cout << rsp.member_info_list(i).avatar() << std::endl; + } +} + + +void csslist_test(const std::string &uid1) { + auto channel = sm->choose(FLAGS_friend_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::FriendService_Stub stub(channel.get()); + bite_im::GetChatSessionListReq req; + bite_im::GetChatSessionListRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid1); + brpc::Controller cntl; + std::cout << "发送获取聊天会话列表请求!!\n"; + stub.GetChatSessionList(&cntl, &req, &rsp, nullptr); + std::cout << "请求发送完毕1!!\n"; + ASSERT_FALSE(cntl.Failed()); + std::cout << "请求发送完毕2!!\n"; + ASSERT_TRUE(rsp.success()); + std::cout << "请求发送完毕,且成功!!\n"; + for (int i = 0; i < rsp.chat_session_info_list_size(); i++) { + std::cout << "-------------------\n"; + std::cout << rsp.chat_session_info_list(i).single_chat_friend_id() << std::endl; + std::cout << rsp.chat_session_info_list(i).chat_session_id() << std::endl; + std::cout << rsp.chat_session_info_list(i).chat_session_name() << std::endl; + std::cout << rsp.chat_session_info_list(i).avatar() << std::endl; + std::cout << "消息内容:\n"; + std::cout << rsp.chat_session_info_list(i).prev_message().message_id() << std::endl; + std::cout << rsp.chat_session_info_list(i).prev_message().chat_session_id() << std::endl; + std::cout << rsp.chat_session_info_list(i).prev_message().timestamp() << std::endl; + std::cout << rsp.chat_session_info_list(i).prev_message().sender().user_id() << std::endl; + std::cout << rsp.chat_session_info_list(i).prev_message().sender().nickname() << std::endl; + std::cout << rsp.chat_session_info_list(i).prev_message().sender().avatar() << std::endl; + std::cout << rsp.chat_session_info_list(i).prev_message().message().file_message().file_name() << std::endl; + std::cout << rsp.chat_session_info_list(i).prev_message().message().file_message().file_contents() << std::endl; + } +} + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + + //1. 先构造Rpc信道管理对象 + sm = std::make_shared(); + sm->declared(FLAGS_friend_service); + auto put_cb = std::bind(&bite_im::ServiceManager::onServiceOnline, sm.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&bite_im::ServiceManager::onServiceOffline, sm.get(), std::placeholders::_1, std::placeholders::_2); + //2. 构造服务发现对象 + bite_im::Discovery::ptr dclient = std::make_shared(FLAGS_etcd_host, FLAGS_base_service, put_cb, del_cb); + + + // apply_test("ee55-9043bfd7-0001", "672f-c755e83e-0000"); + // apply_test("67b1-35ca1b76-0000", "672f-c755e83e-0000"); + // apply_test("d9ff-65692d4a-0001", "672f-c755e83e-0000"); + + // get_apply_list("672f-c755e83e-0000"); + + // process_apply_test("672f-c755e83e-0000", true, "ee55-9043bfd7-0001"); + // process_apply_test("672f-c755e83e-0000", false, "67b1-35ca1b76-0000"); + // process_apply_test("672f-c755e83e-0000", true, "d9ff-65692d4a-0001"); + + // std::cout << "**********************\n"; + // search_test("672f-c755e83e-0000", "猪"); + // std::cout << "++++++++++++++++++++++\n"; + // search_test("ee55-9043bfd7-0001", "猪"); + // std::cout << "======================\n"; + // search_test("67b1-35ca1b76-0000", "乔治"); + + // friend_list_test("c4dc-68239a9a-0001"); + // std::cout << "++++++++++++++++++++++\n"; + // friend_list_test("731f-50086884-0000"); + // std::cout << "++++++++++++++++++++++\n"; + // friend_list_test("31ab-86a1209d-0000"); + + // remove_test("c4dc-68239a9a-0001", "053f-04e5e4c5-0001"); + + // std::vector uidlist = { + // "731f-50086884-0000", + // "c4dc-68239a9a-0001", + // "31ab-86a1209d-0000", + // "053f-04e5e4c5-0001"}; + // create_css_test("731f-50086884-0000", uidlist); + // cssmember_test("731f-50086884-0000", "36b5-edaf4987-0000"); + // std::cout << "++++++++++++++++++++++\n"; + // cssmember_test("c4dc-68239a9a-0001", "36b5-edaf4987-0000"); + + // csslist_test("c4dc-68239a9a-0001"); + return 0; +} \ No newline at end of file diff --git a/friend/test/mysql_test/main.cc b/friend/test/mysql_test/main.cc new file mode 100644 index 0000000..b139e84 --- /dev/null +++ b/friend/test/mysql_test/main.cc @@ -0,0 +1,121 @@ +#include "mysql_chat_session.hpp" +#include "mysql_apply.hpp" +#include "mysql_relation.hpp" +#include + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +void r_insert_test(bite_im::RelationTable &tb) { + tb.insert("用户ID1", "用户ID2"); + tb.insert("用户ID1", "用户ID3"); +} +void r_select_test(bite_im::RelationTable &tb) { + auto res = tb.friends("用户ID1"); + for (auto &uid:res) { + std::cout << uid << std::endl; + } +} +void r_remove_test(bite_im::RelationTable &tb) { + tb.remove("用户ID2", "用户ID1"); +} + +void r_exists_test(bite_im::RelationTable &tb) { + std::cout << tb.exists("用户ID2", "用户ID1") << std::endl; + std::cout << tb.exists("用户ID3", "用户ID1") << std::endl; +} + +void a_insert_test(bite_im::FriendApplyTable &tb) { + bite_im::FriendApply fa1("uuid1", "用户ID1", "用户ID2"); + tb.insert(fa1); + + bite_im::FriendApply fa2("uuid2", "用户ID1", "用户ID3"); + tb.insert(fa2); + + bite_im::FriendApply fa3("uuid3", "用户ID2", "用户ID3"); + tb.insert(fa3); +} +void a_remove_test(bite_im::FriendApplyTable &tb) { + tb.remove("用户ID2", "用户ID3"); +} + +void a_select_test(bite_im::FriendApplyTable &tb) { + // bite_im::FriendApply fa3("uuid3", "用户ID2", "用户ID3"); + // tb.insert(fa3); + + auto res = tb.applyUsers("用户ID2"); + for (auto &uid:res) { + std::cout << uid << std::endl; + } +} +void a_exists_test(bite_im::FriendApplyTable &tb) { + std::cout << tb.exists("用户ID1", "用户ID2") << std::endl; + std::cout << tb.exists("31ab-86a1209d-0000", "c4dc-68239a9a-0001") << std::endl; + std::cout << tb.exists("053f-04e5e4c5-0001", "c4dc-68239a9a-0001") << std::endl; +} + +void c_insert_test(bite_im::ChatSessionTable &tb) { + bite_im::ChatSession cs1("会话ID1", "会话名称1", bite_im::ChatSessionType::SINGLE); + tb.insert(cs1); + bite_im::ChatSession cs2("会话ID2", "会话名称2", bite_im::ChatSessionType::GROUP); + tb.insert(cs2); +} + + +void c_select_test(bite_im::ChatSessionTable &tb) { + auto res = tb.select("会话ID1"); + std::cout << res->chat_session_id() << std::endl; + std::cout << res->chat_session_name() << std::endl; + std::cout << (int)res->chat_session_type() << std::endl; +} + +void c_single_test(bite_im::ChatSessionTable &tb) { + auto res = tb.singleChatSession("731f-50086884-0000"); + for (auto &info : res) { + std::cout << info.chat_session_id << std::endl; + std::cout << info.friend_id << std::endl; + } +} +void c_group_test(bite_im::ChatSessionTable &tb) { + auto res = tb.groupChatSession("用户ID1"); + for (auto &info : res) { + std::cout << info.chat_session_id << std::endl; + std::cout << info.chat_session_name << std::endl; + } +} +void c_remove_test(bite_im::ChatSessionTable &tb) { + tb.remove("会话ID3"); +} +void c_remove_test2(bite_im::ChatSessionTable &tb) { + tb.remove("731f-50086884-0000", "c4dc-68239a9a-0001"); +} + + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + auto db = bite_im::ODBFactory::create("root", "123456", "127.0.0.1", "bite_im", "utf8", 0, 1); + bite_im::RelationTable rtb(db); + bite_im::FriendApplyTable fatb(db); + bite_im::ChatSessionTable cstb(db); + + // r_insert_test(rtb); + // r_select_test(rtb); + // r_remove_test(rtb); + // r_exists_test(rtb); + // a_insert_test(fatb); + // a_remove_test(fatb); + // a_select_test(fatb); + // a_exists_test(fatb); + // c_insert_test(cstb); + // c_select_test(cstb); + // c_single_test(cstb); + // std::cout << "--------------\n"; + // c_group_test(cstb); + // c_remove_test(cstb); + // c_remove_test2(cstb); + return 0; +} \ No newline at end of file diff --git a/gateway/CMakeLists.txt b/gateway/CMakeLists.txt new file mode 100644 index 0000000..157b238 --- /dev/null +++ b/gateway/CMakeLists.txt @@ -0,0 +1,56 @@ +# 1. 添加cmake版本说明 +cmake_minimum_required(VERSION 3.1.3) +# 2. 声明工程名称 +project(gateway_server) + +set(target "gateway_server") + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的proto映射代码文件名称 +set(proto_path ${CMAKE_CURRENT_SOURCE_DIR}/../proto) +set(proto_files base.proto user.proto file.proto friend.proto gateway.proto message.proto notify.proto speech.proto transmite.proto ) +# 2. 检测框架代码文件是否已经生成 +set(proto_hxx "") +set(proto_cxx "") +set(proto_srcs "") +foreach(proto_file ${proto_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".proto" ".pb.cc" proto_cc ${proto_file}) + string(REPLACE ".proto" ".pb.h" proto_hh ${proto_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${proto_cc}) + add_custom_command( + PRE_BUILD + COMMAND protoc + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} -I ${proto_path} --experimental_allow_proto3_optional ${proto_path}/${proto_file} + DEPENDS ${proto_path}/${proto_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + COMMENT "生成Protobuf框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + ) + endif() + list(APPEND proto_srcs ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc}) +endforeach() + +# 4. 获取源码目录下的所有源码文件 +set(src_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/source src_files) +# 5. 声明目标及依赖 +add_executable(${target} ${src_files} ${proto_srcs} ${odb_srcs}) +# 7. 设置需要连接的库 +target_link_libraries(${target} -lgflags + -lspdlog -lfmt -lbrpc -lssl -lcrypto + -lprotobuf -lleveldb -letcd-cpp-api + -lodb-mysql -lodb -lodb-boost + -lhiredis -lredis++ + -lcpprest -lcurl + -lpthread -lboost_system) + + +# 6. 设置头文件默认搜索路径 +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../common) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third/include) + +#8. 设置安装路径 +INSTALL(TARGETS ${target} RUNTIME DESTINATION bin) \ No newline at end of file diff --git a/gateway/dockerfile b/gateway/dockerfile new file mode 100644 index 0000000..904c7da --- /dev/null +++ b/gateway/dockerfile @@ -0,0 +1,16 @@ +# 声明基础经镜像来源 +FROM debian:12 + +# 声明工作目录 +WORKDIR /im +RUN mkdir -p /im/logs &&\ + mkdir -p /im/data &&\ + mkdir -p /im/conf &&\ + mkdir -p /im/bin + +# 将可执行程序依赖,拷贝进镜像 +COPY ./build/gateway_server /im/bin/ +# 将可执行程序文件,拷贝进镜像 +COPY ./depends /lib/x86_64-linux-gnu/ +# 设置容器启动的默认操作 ---运行程序 +CMD /im/bin/gateway_server -flagfile=/im/conf/gateway_server.conf \ No newline at end of file diff --git a/gateway/gateway_server.conf b/gateway/gateway_server.conf new file mode 100644 index 0000000..8c98598 --- /dev/null +++ b/gateway/gateway_server.conf @@ -0,0 +1,17 @@ +-run_mode=true +-log_file=/im/logs/gateway.log +-log_level=0 +-http_listen_port=9000 +-websocket_listen_port=9001 +-registry_host=http://10.0.0.235:2379 +-base_service=/service +-file_service=/service/file_service +-friend_service=/service/friend_service +-message_service=/service/message_service +-user_service=/service/user_service +-speech_service=/service/speech_service +-transmite_service=/service/transmite_service +-redis_host=10.0.0.235 +-redis_port=6379 +-redis_db=0 +-redis_keep_alive=true \ No newline at end of file diff --git a/gateway/source/connection.hpp b/gateway/source/connection.hpp new file mode 100644 index 0000000..9431066 --- /dev/null +++ b/gateway/source/connection.hpp @@ -0,0 +1,64 @@ +#include +#include +#include "logger.hpp" + +namespace bite_im { +typedef websocketpp::server server_t; +// 连接的类型: server_t::connection_ptr + +class Connection { + public: + struct Client { + Client(const std::string &u, const std::string &s):uid(u), ssid(s){} + std::string uid; + std::string ssid; + }; + using ptr = std::shared_ptr; + Connection(){} + ~Connection() {} + void insert(const server_t::connection_ptr &conn, + const std::string &uid, const std::string &ssid) { + std::unique_lock lock(_mutex); + _uid_connections.insert(std::make_pair(uid, conn)); + _conn_clients.insert(std::make_pair(conn, Client(uid, ssid))); + LOG_DEBUG("新增长连接用户信息:{}-{}-{}", (size_t)conn.get(), uid, ssid); + } + server_t::connection_ptr connection(const std::string &uid) { + std::unique_lock lock(_mutex); + auto it = _uid_connections.find(uid); + if (it == _uid_connections.end()) { + LOG_ERROR("未找到 {} 客户端的长连接!", uid); + return server_t::connection_ptr(); + } + LOG_DEBUG("找到 {} 客户端的长连接!", uid); + return it->second; + } + bool client(const server_t::connection_ptr &conn, std::string &uid, std::string &ssid) { + std::unique_lock lock(_mutex); + auto it = _conn_clients.find(conn); + if (it == _conn_clients.end()) { + LOG_ERROR("获取-未找到长连接 {} 对应的客户端信息!", (size_t)conn.get()); + return false; + } + uid = it->second.uid; + ssid = it->second.ssid; + LOG_DEBUG("获取长连接客户端信息成功!"); + return true; + } + void remove(const server_t::connection_ptr &conn) { + std::unique_lock lock(_mutex); + auto it = _conn_clients.find(conn); + if (it == _conn_clients.end()) { + LOG_ERROR("删除-未找到长连接 {} 对应的客户端信息!", (size_t)conn.get()); + return; + } + _uid_connections.erase(it->second.uid); + _conn_clients.erase(it); + LOG_DEBUG("删除长连接信息完毕!"); + } + private: + std::mutex _mutex; + std::unordered_map _uid_connections; + std::unordered_map _conn_clients; +}; +} \ No newline at end of file diff --git a/gateway/source/gateway_server.cc b/gateway/source/gateway_server.cc new file mode 100644 index 0000000..b7505f0 --- /dev/null +++ b/gateway/source/gateway_server.cc @@ -0,0 +1,39 @@ +//主要实现语音识别子服务的服务器的搭建 +#include "gateway_server.hpp" + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_int32(http_listen_port, 9000, "HTTP服务器监听端口"); +DEFINE_int32(websocket_listen_port, 9001, "Websocket服务器监听端口"); + +DEFINE_string(registry_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(file_service, "/service/file_service", "文件存储子服务名称"); +DEFINE_string(friend_service, "/service/friend_service", "好友管理子服务名称"); +DEFINE_string(message_service, "/service/message_service", "消息存储子服务名称"); +DEFINE_string(user_service, "/service/user_service", "用户管理子服务名称"); +DEFINE_string(speech_service, "/service/speech_service", "语音识别子服务名称"); +DEFINE_string(transmite_service, "/service/transmite_service", "转发管理子服务名称"); + +DEFINE_string(redis_host, "127.0.0.1", "Redis服务器访问地址"); +DEFINE_int32(redis_port, 6379, "Redis服务器访问端口"); +DEFINE_int32(redis_db, 0, "Redis默认库号"); +DEFINE_bool(redis_keep_alive, true, "Redis长连接保活选项"); + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + bite_im::GatewayServerBuilder gsb; + gsb.make_redis_object(FLAGS_redis_host, FLAGS_redis_port, FLAGS_redis_db, FLAGS_redis_keep_alive); + gsb.make_discovery_object(FLAGS_registry_host, FLAGS_base_service, FLAGS_file_service, + FLAGS_speech_service, FLAGS_message_service, FLAGS_friend_service, + FLAGS_user_service, FLAGS_transmite_service); + gsb.make_server_object(FLAGS_websocket_listen_port, FLAGS_http_listen_port); + auto server = gsb.build(); + server->start(); + return 0; +} \ No newline at end of file diff --git a/gateway/source/gateway_server.hpp b/gateway/source/gateway_server.hpp new file mode 100644 index 0000000..62127a5 --- /dev/null +++ b/gateway/source/gateway_server.hpp @@ -0,0 +1,1421 @@ +#include "data_redis.hpp" // redis数据管理客户端封装 +#include "etcd.hpp" // 服务注册模块封装 +#include "logger.hpp" // 日志模块封装 +#include "channel.hpp" // 信道管理模块封装 + +#include "connection.hpp" + +#include "user.pb.h" // protobuf框架代码 +#include "base.pb.h" // protobuf框架代码 +#include "file.pb.h" // protobuf框架代码 +#include "friend.pb.h" // protobuf框架代码 +#include "gateway.pb.h" // protobuf框架代码 +#include "message.pb.h" // protobuf框架代码 +#include "speech.pb.h" // protobuf框架代码 +#include "transmite.pb.h" // protobuf框架代码 +#include "notify.pb.h" + +#include "httplib.h" + + +namespace bite_im{ + #define GET_PHONE_VERIFY_CODE "/service/user/get_phone_verify_code" + #define USERNAME_REGISTER "/service/user/username_register" + #define USERNAME_LOGIN "/service/user/username_login" + #define PHONE_REGISTER "/service/user/phone_register" + #define PHONE_LOGIN "/service/user/phone_login" + #define GET_USERINFO "/service/user/get_user_info" + #define SET_USER_AVATAR "/service/user/set_avatar" + #define SET_USER_NICKNAME "/service/user/set_nickname" + #define SET_USER_DESC "/service/user/set_description" + #define SET_USER_PHONE "/service/user/set_phone" + #define FRIEND_GET_LIST "/service/friend/get_friend_list" + #define FRIEND_APPLY "/service/friend/add_friend_apply" + #define FRIEND_APPLY_PROCESS "/service/friend/add_friend_process" + #define FRIEND_REMOVE "/service/friend/remove_friend" + #define FRIEND_SEARCH "/service/friend/search_friend" + #define FRIEND_GET_PENDING_EV "/service/friend/get_pending_friend_events" + #define CSS_GET_LIST "/service/friend/get_chat_session_list" + #define CSS_CREATE "/service/friend/create_chat_session" + #define CSS_GET_MEMBER "/service/friend/get_chat_session_member" + #define MSG_GET_RANGE "/service/message_storage/get_history" + #define MSG_GET_RECENT "/service/message_storage/get_recent" + #define MSG_KEY_SEARCH "/service/message_storage/search_history" + #define NEW_MESSAGE "/service/message_transmit/new_message" + #define FILE_GET_SINGLE "/service/file/get_single_file" + #define FILE_GET_MULTI "/service/file/get_multi_file" + #define FILE_PUT_SINGLE "/service/file/put_single_file" + #define FILE_PUT_MULTI "/service/file/put_multi_file" + #define SPEECH_RECOGNITION "/service/speech/recognition" + class GatewayServer { + public: + using ptr = std::shared_ptr; + GatewayServer( + int websocket_port, + int http_port, + const std::shared_ptr &redis_client, + const ServiceManager::ptr &channels, + const Discovery::ptr &service_discoverer, + const std::string user_service_name, + const std::string file_service_name, + const std::string speech_service_name, + const std::string message_service_name, + const std::string transmite_service_name, + const std::string friend_service_name) + :_redis_session(std::make_shared(redis_client)), + _redis_status(std::make_shared(redis_client)), + _mm_channels(channels), + _service_discoverer(service_discoverer), + _user_service_name(user_service_name), + _file_service_name(file_service_name), + _speech_service_name(speech_service_name), + _message_service_name(message_service_name), + _transmite_service_name(transmite_service_name), + _friend_service_name(friend_service_name), + _connections(std::make_shared()){ + + _ws_server.set_access_channels(websocketpp::log::alevel::none); + _ws_server.init_asio(); + _ws_server.set_open_handler(std::bind(&GatewayServer::onOpen, this, std::placeholders::_1)); + _ws_server.set_close_handler(std::bind(&GatewayServer::onClose, this, std::placeholders::_1)); + auto wscb = std::bind(&GatewayServer::onMessage, this, + std::placeholders::_1, std::placeholders::_2); + _ws_server.set_message_handler(wscb); + _ws_server.set_reuse_addr(true); + _ws_server.listen(websocket_port); + _ws_server.start_accept(); + + _http_server.Post(GET_PHONE_VERIFY_CODE , (httplib::Server::Handler)std::bind(&GatewayServer::GetPhoneVerifyCode , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(USERNAME_REGISTER , (httplib::Server::Handler)std::bind(&GatewayServer::UserRegister , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(USERNAME_LOGIN , (httplib::Server::Handler)std::bind(&GatewayServer::UserLogin , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(PHONE_REGISTER , (httplib::Server::Handler)std::bind(&GatewayServer::PhoneRegister , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(PHONE_LOGIN , (httplib::Server::Handler)std::bind(&GatewayServer::PhoneLogin , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(GET_USERINFO , (httplib::Server::Handler)std::bind(&GatewayServer::GetUserInfo , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(SET_USER_AVATAR , (httplib::Server::Handler)std::bind(&GatewayServer::SetUserAvatar , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(SET_USER_NICKNAME , (httplib::Server::Handler)std::bind(&GatewayServer::SetUserNickname , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(SET_USER_DESC , (httplib::Server::Handler)std::bind(&GatewayServer::SetUserDescription , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(SET_USER_PHONE , (httplib::Server::Handler)std::bind(&GatewayServer::SetUserPhoneNumber , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(FRIEND_GET_LIST , (httplib::Server::Handler)std::bind(&GatewayServer::GetFriendList , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(FRIEND_APPLY , (httplib::Server::Handler)std::bind(&GatewayServer::FriendAdd , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(FRIEND_APPLY_PROCESS , (httplib::Server::Handler)std::bind(&GatewayServer::FriendAddProcess , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(FRIEND_REMOVE , (httplib::Server::Handler)std::bind(&GatewayServer::FriendRemove , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(FRIEND_SEARCH , (httplib::Server::Handler)std::bind(&GatewayServer::FriendSearch , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(FRIEND_GET_PENDING_EV , (httplib::Server::Handler)std::bind(&GatewayServer::GetPendingFriendEventList , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(CSS_GET_LIST , (httplib::Server::Handler)std::bind(&GatewayServer::GetChatSessionList , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(CSS_CREATE , (httplib::Server::Handler)std::bind(&GatewayServer::ChatSessionCreate , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(CSS_GET_MEMBER , (httplib::Server::Handler)std::bind(&GatewayServer::GetChatSessionMember , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(MSG_GET_RANGE , (httplib::Server::Handler)std::bind(&GatewayServer::GetHistoryMsg , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(MSG_GET_RECENT , (httplib::Server::Handler)std::bind(&GatewayServer::GetRecentMsg , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(MSG_KEY_SEARCH , (httplib::Server::Handler)std::bind(&GatewayServer::MsgSearch , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(NEW_MESSAGE , (httplib::Server::Handler)std::bind(&GatewayServer::NewMessage , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(FILE_GET_SINGLE , (httplib::Server::Handler)std::bind(&GatewayServer::GetSingleFile , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(FILE_GET_MULTI , (httplib::Server::Handler)std::bind(&GatewayServer::GetMultiFile , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(FILE_PUT_SINGLE , (httplib::Server::Handler)std::bind(&GatewayServer::PutSingleFile , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(FILE_PUT_MULTI , (httplib::Server::Handler)std::bind(&GatewayServer::PutMultiFile , this, std::placeholders::_1, std::placeholders::_2)); + _http_server.Post(SPEECH_RECOGNITION , (httplib::Server::Handler)std::bind(&GatewayServer::SpeechRecognition , this, std::placeholders::_1, std::placeholders::_2)); + _http_thread = std::thread([this, http_port](){ + _http_server.listen("0.0.0.0", http_port); + }); + _http_thread.detach(); + } + void start() { + _ws_server.run(); + } + private: + void onOpen(websocketpp::connection_hdl hdl) { + LOG_DEBUG("websocket长连接建立成功 {}", (size_t)_ws_server.get_con_from_hdl(hdl).get()); + } + void onClose(websocketpp::connection_hdl hdl) { + //长连接断开时做的清理工作 + //0. 通过连接对象,获取对应的用户ID与登录会话ID + auto conn = _ws_server.get_con_from_hdl(hdl); + std::string uid, ssid; + bool ret = _connections->client(conn, uid, ssid); + if (ret == false) { + LOG_WARN("长连接断开,未找到长连接对应的客户端信息!"); + return ; + } + //1. 移除登录会话信息 + _redis_session->remove(ssid); + //2. 移除登录状态信息 + _redis_status->remove(uid); + //3. 移除长连接管理数据 + _connections->remove(conn); + LOG_DEBUG("{} {} {} 长连接断开,清理缓存数据!", ssid, uid, (size_t)conn.get()); + } + void keepAlive(server_t::connection_ptr conn) { + if (!conn || conn->get_state() != websocketpp::session::state::value::open) { + LOG_DEBUG("非正常连接状态,结束连接保活"); + return; + } + conn->ping(""); + _ws_server.set_timer(60000, std::bind(&GatewayServer::keepAlive, this, conn)); + } + void onMessage(websocketpp::connection_hdl hdl, server_t::message_ptr msg) { + //收到第一条消息后,根据消息中的会话ID进行身份识别,将客户端长连接添加管理 + //1. 取出长连接对应的连接对象 + auto conn = _ws_server.get_con_from_hdl(hdl); + //2. 针对消息内容进行反序列化 -- ClientAuthenticationReq -- 提取登录会话ID + ClientAuthenticationReq request; + bool ret = request.ParseFromString(msg->get_payload()); + if (ret == false) { + LOG_ERROR("长连接身份识别失败:正文反序列化失败!"); + _ws_server.close(hdl, websocketpp::close::status::unsupported_data, "正文反序列化失败!"); + return; + } + //3. 在会话信息缓存中,查找会话信息 + std::string ssid = request.session_id(); + auto uid = _redis_session->uid(ssid); + //4. 会话信息不存在则关闭连接 + if (!uid) { + LOG_ERROR("长连接身份识别失败:未找到会话信息 {}!", ssid); + _ws_server.close(hdl, websocketpp::close::status::unsupported_data, "未找到会话信息!"); + return; + } + //5. 会话信息存在,则添加长连接管理 + _connections->insert(conn, *uid, ssid); + LOG_DEBUG("新增长连接管理:{}-{}-{}", ssid, *uid, (size_t)conn.get()); + keepAlive(conn); + } + void GetPhoneVerifyCode(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + PhoneVerifyCodeReq req; + PhoneVerifyCodeRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("获取短信验证码请求正文反序列化失败!"); + return err_response("获取短信验证码请求正文反序列化失败!"); + } + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetPhoneVerifyCode(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return err_response("用户子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void UserRegister(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + UserRegisterReq req; + UserRegisterRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("用户名注册请求正文反序列化失败!"); + return err_response("用户名注册请求正文反序列化失败!"); + } + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.UserRegister(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return err_response("用户子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void UserLogin(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + UserLoginReq req; + UserLoginRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("用户登录请求正文反序列化失败!"); + return err_response("用户登录请求正文反序列化失败!"); + } + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.UserLogin(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return err_response("用户子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void PhoneRegister(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + PhoneRegisterReq req; + PhoneRegisterRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("手机号注册请求正文反序列化失败!"); + return err_response("手机号注册请求正文反序列化失败!"); + } + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.PhoneRegister(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return err_response("用户子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void PhoneLogin(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + PhoneLoginReq req; + PhoneLoginRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("手机号登录请求正文反序列化失败!"); + return err_response("手机号登录请求正文反序列化失败!"); + } + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.PhoneLogin(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return err_response("用户子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void GetUserInfo(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + GetUserInfoReq req; + GetUserInfoRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("获取用户信息请求正文反序列化失败!"); + return err_response("获取用户信息请求正文反序列化失败!"); + } + //2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetUserInfo(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return err_response("用户子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void SetUserAvatar(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + SetUserAvatarReq req; + SetUserAvatarRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("用户头像设置请求正文反序列化失败!"); + return err_response("用户头像设置请求正文反序列化失败!"); + } + //2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.SetUserAvatar(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return err_response("用户子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void SetUserNickname(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + SetUserNicknameReq req; + SetUserNicknameRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("用户昵称设置请求正文反序列化失败!"); + return err_response("用户昵称设置请求正文反序列化失败!"); + } + //2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.SetUserNickname(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return err_response("用户子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void SetUserDescription(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + SetUserDescriptionReq req; + SetUserDescriptionRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("用户签名设置请求正文反序列化失败!"); + return err_response("用户签名设置请求正文反序列化失败!"); + } + //2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.SetUserDescription(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return err_response("用户子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void SetUserPhoneNumber(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + SetUserPhoneNumberReq req; + SetUserPhoneNumberRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("用户手机号设置请求正文反序列化失败!"); + return err_response("用户手机号设置请求正文反序列化失败!"); + } + //2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.SetUserPhoneNumber(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return err_response("用户子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void GetFriendList(const httplib::Request &request, httplib::Response &response) { + //1. 取出http请求正文,将正文进行反序列化 + GetFriendListReq req; + GetFriendListRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("获取好友列表请求正文反序列化失败!"); + return err_response("获取好友列表请求正文反序列化失败!"); + } + //2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + //2. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_friend_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FriendService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetFriendList(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 好友子服务调用失败!", req.request_id()); + return err_response("好友子服务调用失败!"); + } + //3. 得到用户子服务的响应后,将响应内容进行序列化作为http响应正文 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + + std::shared_ptr _GetUserInfo(const std::string &rid, const std::string &uid) { + GetUserInfoReq req; + auto rsp = std::make_shared(); + req.set_request_id(rid); + req.set_user_id(uid); + //2. 将请求转发给用户子服务进行业务处理 + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return std::shared_ptr(); + } + bite_im::UserService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetUserInfo(&cntl, &req, rsp.get(), nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 用户子服务调用失败!", req.request_id()); + return std::shared_ptr(); + } + return rsp; + } + void FriendAdd(const httplib::Request &request, httplib::Response &response) { + // 好友申请的业务处理中,好友子服务其实只是在数据库创建了申请事件 + // 网关需要做的事情:当好友子服务将业务处理完毕后,如果处理是成功的--需要通知被申请方 + // 1. 正文的反序列化,提取关键要素:登录会话ID + FriendAddReq req; + FriendAddRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("申请好友请求正文反序列化失败!"); + return err_response("申请好友请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_friend_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FriendService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.FriendAdd(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 好友子服务调用失败!", req.request_id()); + return err_response("好友子服务调用失败!"); + } + // 4. 若业务处理成功 --- 且获取被申请方长连接成功,则向被申请放进行好友申请事件通知 + auto conn = _connections->connection(req.respondent_id()); + if (rsp.success() && conn) { + LOG_DEBUG("找到被申请人 {} 长连接,对其进行好友申请通知", req.respondent_id()); + auto user_rsp = _GetUserInfo(req.request_id(), *uid); + if (!user_rsp) { + LOG_ERROR("{} 获取当前客户端用户信息失败!", req.request_id()); + return err_response("获取当前客户端用户信息失败!"); + } + NotifyMessage notify; + notify.set_notify_type(NotifyType::FRIEND_ADD_APPLY_NOTIFY); + notify.mutable_friend_add_apply()->mutable_user_info()->CopyFrom(user_rsp->user_info()); + conn->send(notify.SerializeAsString(), websocketpp::frame::opcode::value::binary); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void FriendAddProcess(const httplib::Request &request, httplib::Response &response) { + //好友申请的处理----- + FriendAddProcessReq req; + FriendAddProcessRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("好友申请处理请求正文反序列化失败!"); + return err_response("好友申请处理请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_friend_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FriendService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.FriendAddProcess(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 好友子服务调用失败!", req.request_id()); + return err_response("好友子服务调用失败!"); + } + + if (rsp.success()) { + auto process_user_rsp = _GetUserInfo(req.request_id(), *uid); + if (!process_user_rsp) { + LOG_ERROR("{} 获取用户信息失败!", req.request_id()); + return err_response("获取用户信息失败!"); + } + auto apply_user_rsp = _GetUserInfo(req.request_id(), req.apply_user_id()); + if (!process_user_rsp) { + LOG_ERROR("{} 获取用户信息失败!", req.request_id()); + return err_response("获取用户信息失败!"); + } + auto process_conn = _connections->connection(*uid); + if (process_conn) LOG_DEBUG("找到处理人的长连接!"); + else LOG_DEBUG("未找到处理人的长连接!"); + auto apply_conn = _connections->connection(req.apply_user_id()); + if (apply_conn) LOG_DEBUG("找到申请人的长连接!"); + else LOG_DEBUG("未找到申请人的长连接!"); + //4. 将处理结果给申请人进行通知 + if (apply_conn) { + NotifyMessage notify; + notify.set_notify_type(NotifyType::FRIEND_ADD_PROCESS_NOTIFY); + auto process_result = notify.mutable_friend_process_result(); + process_result->mutable_user_info()->CopyFrom(process_user_rsp->user_info()); + process_result->set_agree(req.agree()); + apply_conn->send(notify.SerializeAsString(), + websocketpp::frame::opcode::value::binary); + LOG_DEBUG("对申请人进行申请处理结果通知!"); + } + //5. 若处理结果是同意 --- 会伴随着单聊会话的创建 -- 因此需要对双方进行会话创建的通知 + if (req.agree() && apply_conn) { //对申请人的通知---会话信息就是处理人信息 + NotifyMessage notify; + notify.set_notify_type(NotifyType::CHAT_SESSION_CREATE_NOTIFY); + auto chat_session = notify.mutable_new_chat_session_info(); + chat_session->mutable_chat_session_info()->set_single_chat_friend_id(*uid); + chat_session->mutable_chat_session_info()->set_chat_session_id(rsp.new_session_id()); + chat_session->mutable_chat_session_info()->set_chat_session_name(process_user_rsp->user_info().nickname()); + chat_session->mutable_chat_session_info()->set_avatar(process_user_rsp->user_info().avatar()); + apply_conn->send(notify.SerializeAsString(), websocketpp::frame::opcode::value::binary); + LOG_DEBUG("对申请人进行会话创建通知!"); + } + if (req.agree() && process_conn) { //对处理人的通知 --- 会话信息就是申请人信息 + NotifyMessage notify; + notify.set_notify_type(NotifyType::CHAT_SESSION_CREATE_NOTIFY); + auto chat_session = notify.mutable_new_chat_session_info(); + chat_session->mutable_chat_session_info()->set_single_chat_friend_id(req.apply_user_id()); + chat_session->mutable_chat_session_info()->set_chat_session_id(rsp.new_session_id()); + chat_session->mutable_chat_session_info()->set_chat_session_name(apply_user_rsp->user_info().nickname()); + chat_session->mutable_chat_session_info()->set_avatar(apply_user_rsp->user_info().avatar()); + process_conn->send(notify.SerializeAsString(), websocketpp::frame::opcode::value::binary); + LOG_DEBUG("对处理人进行会话创建通知!"); + } + } + //6. 对客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void FriendRemove(const httplib::Request &request, httplib::Response &response) { + // 1. 正文的反序列化,提取关键要素:登录会话ID + FriendRemoveReq req; + FriendRemoveRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("删除好友请求正文反序列化失败!"); + return err_response("删除好友请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_friend_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FriendService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.FriendRemove(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 好友子服务调用失败!", req.request_id()); + return err_response("好友子服务调用失败!"); + } + // 4. 若业务处理成功 --- 且获取被申请方长连接成功,则向被申请放进行好友申请事件通知 + auto conn = _connections->connection(req.peer_id()); + if (rsp.success() && conn) { + LOG_ERROR("对被删除人 {} 进行好友删除通知!", req.peer_id()); + NotifyMessage notify; + notify.set_notify_type(NotifyType::FRIEND_REMOVE_NOTIFY); + notify.mutable_friend_remove()->set_user_id(*uid); + conn->send(notify.SerializeAsString(), websocketpp::frame::opcode::value::binary); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void FriendSearch(const httplib::Request &request, httplib::Response &response) { + FriendSearchReq req; + FriendSearchRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("用户搜索请求正文反序列化失败!"); + return err_response("用户搜索请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_friend_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FriendService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.FriendSearch(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 好友子服务调用失败!", req.request_id()); + return err_response("好友子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void GetPendingFriendEventList(const httplib::Request &request, httplib::Response &response) { + GetPendingFriendEventListReq req; + GetPendingFriendEventListRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("获取待处理好友申请请求正文反序列化失败!"); + return err_response("获取待处理好友申请请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_friend_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FriendService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetPendingFriendEventList(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 好友子服务调用失败!", req.request_id()); + return err_response("好友子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void GetChatSessionList(const httplib::Request &request, httplib::Response &response) { + GetChatSessionListReq req; + GetChatSessionListRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("获取聊天会话列表请求正文反序列化失败!"); + return err_response("获取聊天会话列表请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_friend_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FriendService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetChatSessionList(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 好友子服务调用失败!", req.request_id()); + return err_response("好友子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void GetChatSessionMember(const httplib::Request &request, httplib::Response &response) { + GetChatSessionMemberReq req; + GetChatSessionMemberRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("获取聊天会话成员请求正文反序列化失败!"); + return err_response("获取聊天会话成员请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_friend_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FriendService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetChatSessionMember(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 好友子服务调用失败!", req.request_id()); + return err_response("好友子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void ChatSessionCreate(const httplib::Request &request, httplib::Response &response) { + ChatSessionCreateReq req; + ChatSessionCreateRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("创建聊天会话请求正文反序列化失败!"); + return err_response("创建聊天会话请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_friend_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FriendService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.ChatSessionCreate(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 好友子服务调用失败!", req.request_id()); + return err_response("好友子服务调用失败!"); + } + // 4. 若业务处理成功 --- 且获取被申请方长连接成功,则向被申请放进行好友申请事件通知 + if (rsp.success()){ + for (int i = 0; i < req.member_id_list_size(); i++) { + auto conn = _connections->connection(req.member_id_list(i)); + if (!conn) { + LOG_DEBUG("未找到群聊成员 {} 长连接", req.member_id_list(i)); + continue; + } + NotifyMessage notify; + notify.set_notify_type(NotifyType::CHAT_SESSION_CREATE_NOTIFY); + auto chat_session = notify.mutable_new_chat_session_info(); + chat_session->mutable_chat_session_info()->CopyFrom(rsp.chat_session_info()); + conn->send(notify.SerializeAsString(), websocketpp::frame::opcode::value::binary); + LOG_DEBUG("对群聊成员 {} 进行会话创建通知", req.member_id_list(i)); + } + } + // 5. 向客户端进行响应 + rsp.clear_chat_session_info(); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void GetHistoryMsg(const httplib::Request &request, httplib::Response &response) { + GetHistoryMsgReq req; + GetHistoryMsgRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("获取区间消息请求正文反序列化失败!"); + return err_response("获取区间消息请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_message_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::MsgStorageService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetHistoryMsg(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 消息存储子服务调用失败!", req.request_id()); + return err_response("消息存储子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void GetRecentMsg(const httplib::Request &request, httplib::Response &response) { + GetRecentMsgReq req; + GetRecentMsgRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("获取最近消息请求正文反序列化失败!"); + return err_response("获取最近消息请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_message_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::MsgStorageService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetRecentMsg(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 消息存储子服务调用失败!", req.request_id()); + return err_response("消息存储子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void MsgSearch(const httplib::Request &request, httplib::Response &response) { + MsgSearchReq req; + MsgSearchRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("消息搜索请求正文反序列化失败!"); + return err_response("消息搜索请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_message_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::MsgStorageService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.MsgSearch(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 消息存储子服务调用失败!", req.request_id()); + return err_response("消息存储子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void GetSingleFile(const httplib::Request &request, httplib::Response &response) { + GetSingleFileReq req; + GetSingleFileRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("单文件下载请求正文反序列化失败!"); + return err_response("单文件下载请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_file_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FileService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetSingleFile(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 文件存储子服务调用失败!", req.request_id()); + return err_response("文件存储子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void GetMultiFile(const httplib::Request &request, httplib::Response &response) { + GetMultiFileReq req; + GetMultiFileRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("单文件下载请求正文反序列化失败!"); + return err_response("单文件下载请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_file_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FileService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetMultiFile(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 文件存储子服务调用失败!", req.request_id()); + return err_response("文件存储子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void PutSingleFile(const httplib::Request &request, httplib::Response &response) { + PutSingleFileReq req; + PutSingleFileRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("单文件上传请求正文反序列化失败!"); + return err_response("单文件上传请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_file_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FileService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.PutSingleFile(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 文件存储子服务调用失败!", req.request_id()); + return err_response("文件存储子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void PutMultiFile(const httplib::Request &request, httplib::Response &response) { + PutMultiFileReq req; + PutMultiFileRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("批量文件上传请求正文反序列化失败!"); + return err_response("批量文件上传请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_file_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::FileService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.PutMultiFile(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 文件存储子服务调用失败!", req.request_id()); + return err_response("文件存储子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + void SpeechRecognition(const httplib::Request &request, httplib::Response &response) { + LOG_DEBUG("收到语音转文字请求!"); + SpeechRecognitionReq req; + SpeechRecognitionRsp rsp; + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("语音识别请求正文反序列化失败!"); + return err_response("语音识别请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_speech_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::SpeechService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.SpeechRecognition(&cntl, &req, &rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 语音识别子服务调用失败!", req.request_id()); + return err_response("语音识别子服务调用失败!"); + } + // 5. 向客户端进行响应 + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + + void NewMessage(const httplib::Request &request, httplib::Response &response) { + NewMessageReq req; + NewMessageRsp rsp;//这是给客户端的响应 + GetTransmitTargetRsp target_rsp;//这是请求子服务的响应 + auto err_response = [&req, &rsp, &response](const std::string &errmsg) -> void { + rsp.set_success(false); + rsp.set_errmsg(errmsg); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + }; + bool ret = req.ParseFromString(request.body); + if (ret == false) { + LOG_ERROR("新消息请求正文反序列化失败!"); + return err_response("新消息请求正文反序列化失败!"); + } + // 2. 客户端身份识别与鉴权 + std::string ssid = req.session_id(); + auto uid = _redis_session->uid(ssid); + if (!uid) { + LOG_ERROR("{} 获取登录会话关联用户信息失败!", ssid); + return err_response("获取登录会话关联用户信息失败!"); + } + req.set_user_id(*uid); + // 3. 将请求转发给好友子服务进行业务处理 + auto channel = _mm_channels->choose(_transmite_service_name); + if (!channel) { + LOG_ERROR("{} 未找到可提供业务处理的用户子服务节点!", req.request_id()); + return err_response("未找到可提供业务处理的用户子服务节点!"); + } + bite_im::MsgTransmitService_Stub stub(channel.get()); + brpc::Controller cntl; + stub.GetTransmitTarget(&cntl, &req, &target_rsp, nullptr); + if (cntl.Failed()) { + LOG_ERROR("{} 消息转发子服务调用失败!", req.request_id()); + return err_response("消息转发子服务调用失败!"); + } + // 4. 若业务处理成功 --- 且获取被申请方长连接成功,则向被申请放进行好友申请事件通知 + if (target_rsp.success()){ + for (int i = 0; i < target_rsp.target_id_list_size(); i++) { + std::string notify_uid = target_rsp.target_id_list(i); + if (notify_uid == *uid) continue; //不通知自己 + auto conn = _connections->connection(notify_uid); + if (!conn) { continue;} + NotifyMessage notify; + notify.set_notify_type(NotifyType::CHAT_MESSAGE_NOTIFY); + auto msg_info = notify.mutable_new_message_info(); + msg_info->mutable_message_info()->CopyFrom(target_rsp.message()); + conn->send(notify.SerializeAsString(), websocketpp::frame::opcode::value::binary); + } + } + // 5. 向客户端进行响应 + rsp.set_request_id(req.request_id()); + rsp.set_success(target_rsp.success()); + rsp.set_errmsg(target_rsp.errmsg()); + response.set_content(rsp.SerializeAsString(), "application/x-protbuf"); + } + private: + Session::ptr _redis_session; + Status::ptr _redis_status; + + std::string _user_service_name; + std::string _file_service_name; + std::string _speech_service_name; + std::string _message_service_name; + std::string _transmite_service_name; + std::string _friend_service_name; + ServiceManager::ptr _mm_channels; + Discovery::ptr _service_discoverer; + + Connection::ptr _connections; + + server_t _ws_server; + httplib::Server _http_server; + std::thread _http_thread; + }; + + class GatewayServerBuilder { + public: + //构造redis客户端对象 + void make_redis_object(const std::string &host, + int port, + int db, + bool keep_alive) { + _redis_client = RedisClientFactory::create(host, port, db, keep_alive); + } + //用于构造服务发现客户端&信道管理对象 + void make_discovery_object(const std::string ®_host, + const std::string &base_service_name, + const std::string &file_service_name, + const std::string &speech_service_name, + const std::string &message_service_name, + const std::string &friend_service_name, + const std::string &user_service_name, + const std::string &transmite_service_name) { + _file_service_name = file_service_name; + _speech_service_name = speech_service_name; + _message_service_name = message_service_name; + _friend_service_name = friend_service_name; + _user_service_name = user_service_name; + _transmite_service_name = transmite_service_name; + _mm_channels = std::make_shared(); + _mm_channels->declared(file_service_name); + _mm_channels->declared(speech_service_name); + _mm_channels->declared(message_service_name); + _mm_channels->declared(friend_service_name); + _mm_channels->declared(user_service_name); + _mm_channels->declared(transmite_service_name); + auto put_cb = std::bind(&ServiceManager::onServiceOnline, _mm_channels.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&ServiceManager::onServiceOffline, _mm_channels.get(), std::placeholders::_1, std::placeholders::_2); + _service_discoverer = std::make_shared(reg_host, base_service_name, put_cb, del_cb); + } + void make_server_object(int websocket_port, int http_port) { + _websocket_port = websocket_port; + _http_port = http_port; + } + //构造RPC服务器对象 + GatewayServer::ptr build() { + if (!_redis_client) { + LOG_ERROR("还未初始化Redis客户端模块!"); + abort(); + } + if (!_service_discoverer) { + LOG_ERROR("还未初始化服务发现模块!"); + abort(); + } + if (!_mm_channels) { + LOG_ERROR("还未初始化信道管理模块!"); + abort(); + } + GatewayServer::ptr server = std::make_shared( + _websocket_port, _http_port, _redis_client, _mm_channels, + _service_discoverer, _user_service_name, _file_service_name, + _speech_service_name, _message_service_name, + _transmite_service_name, _friend_service_name); + return server; + } + private: + int _websocket_port; + int _http_port; + + std::shared_ptr _redis_client; + + std::string _file_service_name; + std::string _speech_service_name; + std::string _message_service_name; + std::string _friend_service_name; + std::string _user_service_name; + std::string _transmite_service_name; + ServiceManager::ptr _mm_channels; + Discovery::ptr _service_discoverer; + }; +} \ No newline at end of file diff --git a/message/CMakeLists.txt b/message/CMakeLists.txt new file mode 100644 index 0000000..4b3893c --- /dev/null +++ b/message/CMakeLists.txt @@ -0,0 +1,89 @@ +# 1. 添加cmake版本说明 +cmake_minimum_required(VERSION 3.1.3) +# 2. 声明工程名称 +project(message_server) + +set(target "message_server") + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的proto映射代码文件名称 +set(proto_path ${CMAKE_CURRENT_SOURCE_DIR}/../proto) +set(proto_files base.proto user.proto file.proto message.proto) +# 2. 检测框架代码文件是否已经生成 +set(proto_hxx "") +set(proto_cxx "") +set(proto_srcs "") +foreach(proto_file ${proto_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".proto" ".pb.cc" proto_cc ${proto_file}) + string(REPLACE ".proto" ".pb.h" proto_hh ${proto_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${proto_cc}) + add_custom_command( + PRE_BUILD + COMMAND protoc + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} -I ${proto_path} --experimental_allow_proto3_optional ${proto_path}/${proto_file} + DEPENDS ${proto_path}/${proto_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + COMMENT "生成Protobuf框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + ) + endif() + list(APPEND proto_srcs ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc}) +endforeach() + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的odb映射代码文件名称 +set(odb_path ${CMAKE_CURRENT_SOURCE_DIR}/../odb) +set(odb_files message.hxx) +# 2. 检测框架代码文件是否已经生成 +set(odb_hxx "") +set(odb_cxx "") +set(odb_srcs "") +foreach(odb_file ${odb_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".hxx" "-odb.hxx" odb_hxx ${odb_file}) + string(REPLACE ".hxx" "-odb.cxx" odb_cxx ${odb_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${odb_cxx}) + add_custom_command( + PRE_BUILD + COMMAND odb + ARGS -d mysql --std c++11 --generate-query --generate-schema --profile boost/date-time ${odb_path}/${odb_file} + DEPENDS ${odb_path}/${odb_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx} + COMMENT "生成ODB框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx} + ) + endif() +# 4. 将所有生成的框架源码文件名称保存起来 student-odb.cxx classes-odb.cxx + list(APPEND odb_srcs ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx}) +endforeach() + +# 4. 获取源码目录下的所有源码文件 +set(src_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/source src_files) +# 5. 声明目标及依赖 +add_executable(${target} ${src_files} ${proto_srcs} ${odb_srcs}) +# 7. 设置需要连接的库 +target_link_libraries(${target} -lgflags + -lspdlog -lfmt -lbrpc -lssl -lcrypto + -lprotobuf -lleveldb -letcd-cpp-api + -lcpprest -lcurl -lodb-mysql -lodb -lodb-boost + /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19 + -lcpr -lelasticlient + -lamqpcpp -lev) + + +set(test_client "message_client") +set(test_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/test test_files) +add_executable(${test_client} ${test_files} ${proto_srcs}) +target_link_libraries(${test_client} -pthread -lgtest -lgflags -lspdlog -lfmt -lbrpc -lssl -lcrypto -lprotobuf -lleveldb -letcd-cpp-api -lcpprest -lcurl /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19) + +# 6. 设置头文件默认搜索路径 +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../common) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../odb) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third/include) + +#8. 设置安装路径 +INSTALL(TARGETS ${target} ${test_client} RUNTIME DESTINATION bin) \ No newline at end of file diff --git a/message/dockerfile b/message/dockerfile new file mode 100644 index 0000000..e5a9b07 --- /dev/null +++ b/message/dockerfile @@ -0,0 +1,16 @@ +# 声明基础经镜像来源 +FROM debian:12 + +# 声明工作目录 +WORKDIR /im +RUN mkdir -p /im/logs &&\ + mkdir -p /im/data &&\ + mkdir -p /im/conf &&\ + mkdir -p /im/bin + +# 将可执行程序依赖,拷贝进镜像 +COPY ./build/message_server /im/bin/ +# 将可执行程序文件,拷贝进镜像 +COPY ./depends /lib/x86_64-linux-gnu/ +# 设置容器启动的默认操作 ---运行程序 +CMD /im/bin/message_server -flagfile=/im/conf/message_server.conf \ No newline at end of file diff --git a/message/message_server.conf b/message/message_server.conf new file mode 100644 index 0000000..d4f22cc --- /dev/null +++ b/message/message_server.conf @@ -0,0 +1,26 @@ +-run_mode=true +-log_file=/im/logs/message.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-instance_name=/message_service/instance +-access_host=10.0.0.235:10005 +-listen_port=10005 +-rpc_timeout=-1 +-rpc_threads=1 +-base_service=/service +-user_service=/service/user_service +-file_service=/service/file_service +-es_host=http://10.0.0.235:9200/ +-mysql_host=10.0.0.235 +-mysql_user=root +-mysql_pswd=123456 +-mysql_db=bite_im +-mysql_cset=utf8 +-mysql_port=0 +-mysql_pool_count=4 +-mq_user=root +-mq_pswd=123456 +-mq_host=10.0.0.235:5672 +-mq_msg_exchange=msg_exchange +-mq_msg_queue=msg_queue +-mq_msg_binding_key=msg_queue \ No newline at end of file diff --git a/message/source/message_server.cc b/message/source/message_server.cc new file mode 100644 index 0000000..0f5b876 --- /dev/null +++ b/message/source/message_server.cc @@ -0,0 +1,56 @@ +//主要实现语音识别子服务的服务器的搭建 +#include "message_server.hpp" + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(registry_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(instance_name, "/message_service/instance", "当前实例名称"); +DEFINE_string(access_host, "127.0.0.1:10005", "当前实例的外部访问地址"); + +DEFINE_int32(listen_port, 10005, "Rpc服务器监听端口"); +DEFINE_int32(rpc_timeout, -1, "Rpc调用超时时间"); +DEFINE_int32(rpc_threads, 1, "Rpc的IO线程数量"); + + +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(file_service, "/service/file_service", "文件管理子服务名称"); +DEFINE_string(user_service, "/service/user_service", "用户管理子服务名称"); + +DEFINE_string(es_host, "http://127.0.0.1:9200/", "ES搜索引擎服务器URL"); + +DEFINE_string(mysql_host, "127.0.0.1", "Mysql服务器访问地址"); +DEFINE_string(mysql_user, "root", "Mysql服务器访问用户名"); +DEFINE_string(mysql_pswd, "123456", "Mysql服务器访问密码"); +DEFINE_string(mysql_db, "bite_im", "Mysql默认库名称"); +DEFINE_string(mysql_cset, "utf8", "Mysql客户端字符集"); +DEFINE_int32(mysql_port, 0, "Mysql服务器访问端口"); +DEFINE_int32(mysql_pool_count, 4, "Mysql连接池最大连接数量"); + +DEFINE_string(mq_user, "root", "消息队列服务器访问用户名"); +DEFINE_string(mq_pswd, "123456", "消息队列服务器访问密码"); +DEFINE_string(mq_host, "127.0.0.1:5672", "消息队列服务器访问地址"); +DEFINE_string(mq_msg_exchange, "msg_exchange", "持久化消息的发布交换机名称"); +DEFINE_string(mq_msg_queue, "msg_queue", "持久化消息的发布队列名称"); +DEFINE_string(mq_msg_binding_key, "msg_queue", "持久化消息的发布队列名称"); + + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + bite_im::MessageServerBuilder msb; + msb.make_mq_object(FLAGS_mq_user, FLAGS_mq_pswd, FLAGS_mq_host, + FLAGS_mq_msg_exchange, FLAGS_mq_msg_queue, FLAGS_mq_msg_binding_key); + msb.make_es_object({FLAGS_es_host}); + msb.make_mysql_object(FLAGS_mysql_user, FLAGS_mysql_pswd, FLAGS_mysql_host, + FLAGS_mysql_db, FLAGS_mysql_cset, FLAGS_mysql_port, FLAGS_mysql_pool_count); + msb.make_discovery_object(FLAGS_registry_host, FLAGS_base_service, FLAGS_file_service, FLAGS_user_service); + msb.make_rpc_server(FLAGS_listen_port, FLAGS_rpc_timeout, FLAGS_rpc_threads); + msb.make_registry_object(FLAGS_registry_host, FLAGS_base_service + FLAGS_instance_name, FLAGS_access_host); + auto server = msb.build(); + server->start(); + return 0; +} \ No newline at end of file diff --git a/message/source/message_server.hpp b/message/source/message_server.hpp new file mode 100644 index 0000000..eccebba --- /dev/null +++ b/message/source/message_server.hpp @@ -0,0 +1,579 @@ +//实现语音识别子服务 +#include +#include + +#include "data_es.hpp" // es数据管理客户端封装 +#include "mysql_message.hpp" // mysql数据管理客户端封装 +#include "etcd.hpp" // 服务注册模块封装 +#include "logger.hpp" // 日志模块封装 +#include "utils.hpp" // 基础工具接口 +#include "channel.hpp" // 信道管理模块封装 +#include "rabbitmq.hpp" + +#include "message.pb.h" // protobuf框架代码 +#include "base.pb.h" // protobuf框架代码 +#include "file.pb.h" // protobuf框架代码 +#include "user.pb.h" // protobuf框架代码 + +namespace bite_im{ +class MessageServiceImpl : public bite_im::MsgStorageService { + public: + MessageServiceImpl( + const std::shared_ptr &es_client, + const std::shared_ptr &mysql_client, + const ServiceManager::ptr &channel_manager, + const std::string &file_service_name, + const std::string &user_service_name) : + _es_message(std::make_shared(es_client)), + _mysql_message(std::make_shared(mysql_client)), + _file_service_name(file_service_name), + _user_service_name(user_service_name), + _mm_channels(channel_manager){ + _es_message->createIndex(); + } + ~MessageServiceImpl(){} + virtual void GetHistoryMsg(::google::protobuf::RpcController* controller, + const ::bite_im::GetHistoryMsgReq* request, + ::bite_im::GetHistoryMsgRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //1. 提取关键要素:会话ID,起始时间,结束时间 + std::string rid = request->request_id(); + std::string chat_ssid = request->chat_session_id(); + boost::posix_time::ptime stime = boost::posix_time::from_time_t(request->start_time()); + boost::posix_time::ptime etime = boost::posix_time::from_time_t(request->over_time()); + //2. 从数据库中进行消息查询 + auto msg_lists = _mysql_message->range(chat_ssid, stime, etime); + if (msg_lists.empty()) { + response->set_request_id(rid); + response->set_success(true); + return ; + } + //3. 统计所有文件类型消息的文件ID,并从文件子服务进行批量文件下载 + std::unordered_set file_id_lists; + for (const auto &msg : msg_lists) { + if (msg.file_id().empty()) continue; + LOG_DEBUG("需要下载的文件ID: {}", msg.file_id()); + file_id_lists.insert(msg.file_id()); + } + std::unordered_map file_data_lists; + bool ret = _GetFile(rid, file_id_lists, file_data_lists); + if (ret == false) { + LOG_ERROR("{} 批量文件数据下载失败!", rid); + return err_response(rid, "批量文件数据下载失败!"); + } + //4. 统计所有消息的发送者用户ID,从用户子服务进行批量用户信息获取 + std::unordered_set user_id_lists; // {猪爸爸吧, 祝妈妈,猪爸爸吧,祝爸爸} + for (const auto &msg : msg_lists) { + user_id_lists.insert(msg.user_id()); + } + std::unordered_map user_lists; + ret = _GetUser(rid, user_id_lists, user_lists); + if (ret == false) { + LOG_ERROR("{} 批量用户数据获取失败!", rid); + return err_response(rid, "批量用户数据获取失败!"); + } + //5. 组织响应 + response->set_request_id(rid); + response->set_success(true); + for (const auto &msg : msg_lists) { + auto message_info = response->add_msg_list(); + message_info->set_message_id(msg.message_id()); + message_info->set_chat_session_id(msg.session_id()); + message_info->set_timestamp(boost::posix_time::to_time_t(msg.create_time())); + message_info->mutable_sender()->CopyFrom(user_lists[msg.user_id()]); + switch(msg.message_type()) { + case MessageType::STRING: + message_info->mutable_message()->set_message_type(MessageType::STRING); + message_info->mutable_message()->mutable_string_message()->set_content(msg.content()); + break; + case MessageType::IMAGE: + message_info->mutable_message()->set_message_type(MessageType::IMAGE); + message_info->mutable_message()->mutable_image_message()->set_file_id(msg.file_id()); + message_info->mutable_message()->mutable_image_message()->set_image_content(file_data_lists[msg.file_id()]); + break; + case MessageType::FILE: + message_info->mutable_message()->set_message_type(MessageType::FILE); + message_info->mutable_message()->mutable_file_message()->set_file_id(msg.file_id()); + message_info->mutable_message()->mutable_file_message()->set_file_size(msg.file_size()); + message_info->mutable_message()->mutable_file_message()->set_file_name(msg.file_name()); + message_info->mutable_message()->mutable_file_message()->set_file_contents(file_data_lists[msg.file_id()]); + break; + case MessageType::SPEECH: + message_info->mutable_message()->set_message_type(MessageType::SPEECH); + message_info->mutable_message()->mutable_speech_message()->set_file_id(msg.file_id()); + message_info->mutable_message()->mutable_speech_message()->set_file_contents(file_data_lists[msg.file_id()]); + break; + default: + LOG_ERROR("消息类型错误!!"); + return; + } + } + return; + } + virtual void GetRecentMsg(::google::protobuf::RpcController* controller, + const ::bite_im::GetRecentMsgReq* request, + ::bite_im::GetRecentMsgRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //1. 提取请求中的关键要素:请求ID,会话ID,要获取的消息数量 + std::string rid = request->request_id(); + std::string chat_ssid = request->chat_session_id(); + int msg_count = request->msg_count(); + //2. 从数据库,获取最近的消息元信息 + auto msg_lists = _mysql_message->recent(chat_ssid, msg_count); + if (msg_lists.empty()) { + response->set_request_id(rid); + response->set_success(true); + return ; + } + //3. 统计所有消息中文件类型消息的文件ID列表,从文件子服务下载文件 + std::unordered_set file_id_lists; + for (const auto &msg : msg_lists) { + if (msg.file_id().empty()) continue; + LOG_DEBUG("需要下载的文件ID: {}", msg.file_id()); + file_id_lists.insert(msg.file_id()); + } + std::unordered_map file_data_lists; + bool ret = _GetFile(rid, file_id_lists, file_data_lists); + if (ret == false) { + LOG_ERROR("{} 批量文件数据下载失败!", rid); + return err_response(rid, "批量文件数据下载失败!"); + } + //4. 统计所有消息的发送者用户ID,从用户子服务进行批量用户信息获取 + std::unordered_set user_id_lists; + for (const auto &msg : msg_lists) { + user_id_lists.insert(msg.user_id()); + } + std::unordered_map user_lists; + ret = _GetUser(rid, user_id_lists, user_lists); + if (ret == false) { + LOG_ERROR("{} 批量用户数据获取失败!", rid); + return err_response(rid, "批量用户数据获取失败!"); + } + //5. 组织响应 + response->set_request_id(rid); + response->set_success(true); + for (const auto &msg : msg_lists) { + auto message_info = response->add_msg_list(); + message_info->set_message_id(msg.message_id()); + message_info->set_chat_session_id(msg.session_id()); + message_info->set_timestamp(boost::posix_time::to_time_t(msg.create_time())); + message_info->mutable_sender()->CopyFrom(user_lists[msg.user_id()]); + switch(msg.message_type()) { + case MessageType::STRING: + message_info->mutable_message()->set_message_type(MessageType::STRING); + message_info->mutable_message()->mutable_string_message()->set_content(msg.content()); + break; + case MessageType::IMAGE: + message_info->mutable_message()->set_message_type(MessageType::IMAGE); + message_info->mutable_message()->mutable_image_message()->set_file_id(msg.file_id()); + message_info->mutable_message()->mutable_image_message()->set_image_content(file_data_lists[msg.file_id()]); + break; + case MessageType::FILE: + message_info->mutable_message()->set_message_type(MessageType::FILE); + message_info->mutable_message()->mutable_file_message()->set_file_id(msg.file_id()); + message_info->mutable_message()->mutable_file_message()->set_file_size(msg.file_size()); + message_info->mutable_message()->mutable_file_message()->set_file_name(msg.file_name()); + message_info->mutable_message()->mutable_file_message()->set_file_contents(file_data_lists[msg.file_id()]); + break; + case MessageType::SPEECH: + message_info->mutable_message()->set_message_type(MessageType::SPEECH); + message_info->mutable_message()->mutable_speech_message()->set_file_id(msg.file_id()); + message_info->mutable_message()->mutable_speech_message()->set_file_contents(file_data_lists[msg.file_id()]); + break; + default: + LOG_ERROR("消息类型错误!!"); + return; + } + } + return; + } + virtual void MsgSearch(::google::protobuf::RpcController* controller, + const ::bite_im::MsgSearchReq* request, + ::bite_im::MsgSearchRsp* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //关键字的消息搜索--只针对文本消息 + //1. 从请求中提取关键要素:请求ID,会话ID, 关键字 + std::string rid = request->request_id(); + std::string chat_ssid = request->chat_session_id(); + std::string skey = request->search_key(); + //2. 从ES搜索引擎中进行关键字消息搜索,得到消息列表 + auto msg_lists = _es_message->search(skey, chat_ssid); + if (msg_lists.empty()) { + response->set_request_id(rid); + response->set_success(true); + return ; + } + //3. 组织所有消息的用户ID,从用户子服务获取用户信息 + std::unordered_set user_id_lists; + for (const auto &msg : msg_lists) { + user_id_lists.insert(msg.user_id()); + } + std::unordered_map user_lists; + bool ret = _GetUser(rid, user_id_lists, user_lists); + if (ret == false) { + LOG_ERROR("{} 批量用户数据获取失败!", rid); + return err_response(rid, "批量用户数据获取失败!"); + } + //4. 组织响应 + response->set_request_id(rid); + response->set_success(true); + for (const auto &msg : msg_lists) { + auto message_info = response->add_msg_list(); + message_info->set_message_id(msg.message_id()); + message_info->set_chat_session_id(msg.session_id()); + message_info->set_timestamp(boost::posix_time::to_time_t(msg.create_time())); + message_info->mutable_sender()->CopyFrom(user_lists[msg.user_id()]); + message_info->mutable_message()->set_message_type(MessageType::STRING); + message_info->mutable_message()->mutable_string_message()->set_content(msg.content()); + } + return; + } + + //rabbitmq获得消息 + void onMessage(const char *body, size_t sz) { + LOG_DEBUG("收到新消息,进行存储处理!"); + //1. 取出序列化的消息内容,进行反序列化 + bite_im::MessageInfo message; + bool ret = message.ParseFromArray(body, sz); + if (ret == false) { + LOG_ERROR("对消费到的消息进行反序列化失败!"); + return; + } + //2. 根据不同的消息类型进行不同的处理 + std::string file_id, file_name, content; + int64_t file_size; + switch(message.message().message_type()) { + // 1. 如果是一个文本类型消息,取元信息存储到ES中 + case MessageType::STRING: + content = message.message().string_message().content(); + ret = _es_message->appendData( + message.sender().user_id(), + message.message_id(), + message.timestamp(), + message.chat_session_id(), + content); + if (ret == false) { + LOG_ERROR("文本消息向存储引擎进行存储失败!"); + return; + } + break; + // 2. 如果是一个图片/语音/文件消息,则取出数据存储到文件子服务中,并获取文件ID + case MessageType::IMAGE: + { + const auto &msg = message.message().image_message(); + ret = _PutFile("", msg.image_content(), msg.image_content().size(), file_id); + if (ret == false) { + LOG_ERROR("上传图片到文件子服务失败!"); + return ; + } + } + break; + case MessageType::FILE: + { + const auto &msg = message.message().file_message(); + file_name = msg.file_name(); + file_size = msg.file_size(); + ret = _PutFile(file_name, msg.file_contents(), file_size, file_id); + if (ret == false) { + LOG_ERROR("上传文件到文件子服务失败!"); + return ; + } + } + break; + case MessageType::SPEECH: + { + const auto &msg = message.message().speech_message(); + ret = _PutFile("", msg.file_contents(), msg.file_contents().size(), file_id); + if (ret == false) { + LOG_ERROR("上传语音到文件子服务失败!"); + return ; + } + } + break; + default: + LOG_ERROR("消息类型错误!"); + return; + } + //3. 提取消息的元信息,存储到mysql数据库中 + bite_im::Message msg(message.message_id(), + message.chat_session_id(), + message.sender().user_id(), + message.message().message_type(), + boost::posix_time::from_time_t(message.timestamp())); + msg.content(content); + msg.file_id(file_id); + msg.file_name(file_name); + msg.file_size(file_size); + ret = _mysql_message->insert(msg); + if (ret == false) { + LOG_ERROR("向数据库插入新消息失败!"); + return; + } + } + private: + bool _GetUser(const std::string &rid, + const std::unordered_set &user_id_lists, + std::unordered_map &user_lists) { + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{} 没有可供访问的用户子服务节点!", _user_service_name); + return false; + } + UserService_Stub stub(channel.get()); + GetMultiUserInfoReq req; + GetMultiUserInfoRsp rsp; + req.set_request_id(rid); + for (const auto &id : user_id_lists) { + req.add_users_id(id); + } + brpc::Controller cntl; + stub.GetMultiUserInfo(&cntl, &req, &rsp, nullptr); + if (cntl.Failed() == true || rsp.success() == false) { + LOG_ERROR("用户子服务调用失败:{}!", cntl.ErrorText()); + return false; + } + const auto &umap = rsp.users_info(); + for (auto it = umap.begin(); it != umap.end(); ++it) { + user_lists.insert(std::make_pair(it->first, it->second)); + } + return true; + } + bool _GetFile(const std::string &rid, + const std::unordered_set &file_id_lists, + std::unordered_map &file_data_lists) { + auto channel = _mm_channels->choose(_file_service_name); + if (!channel) { + LOG_ERROR("{} 没有可供访问的文件子服务节点!", _file_service_name); + return false; + } + FileService_Stub stub(channel.get()); + GetMultiFileReq req; + GetMultiFileRsp rsp; + req.set_request_id(rid); + for (const auto &id : file_id_lists) { + req.add_file_id_list(id); + } + brpc::Controller cntl; + stub.GetMultiFile(&cntl, &req, &rsp, nullptr); + if (cntl.Failed() == true || rsp.success() == false) { + LOG_ERROR("文件子服务调用失败:{}!", cntl.ErrorText()); + return false; + } + const auto &fmap = rsp.file_data(); + for (auto it = fmap.begin(); it != fmap.end(); ++it) { + file_data_lists.insert(std::make_pair(it->first, it->second.file_content())); + } + return true; + } + bool _PutFile(const std::string &filename, + const std::string &body, + const int64_t fsize, + std::string &file_id) { + //实现文件数据的上传 + auto channel = _mm_channels->choose(_file_service_name); + if (!channel) { + LOG_ERROR("{} 没有可供访问的文件子服务节点!", _file_service_name); + return false; + } + FileService_Stub stub(channel.get()); + PutSingleFileReq req; + PutSingleFileRsp rsp; + req.mutable_file_data()->set_file_name(filename); + req.mutable_file_data()->set_file_size(fsize); + req.mutable_file_data()->set_file_content(body); + brpc::Controller cntl; + stub.PutSingleFile(&cntl, &req, &rsp, nullptr); + if (cntl.Failed() == true || rsp.success() == false) { + LOG_ERROR("文件子服务调用失败:{}!", cntl.ErrorText()); + return false; + } + file_id = rsp.file_info().file_id(); + return true; + } + private: + ESMessage::ptr _es_message; + MessageTable::ptr _mysql_message; + //这边是rpc调用客户端相关对象 + std::string _user_service_name; + std::string _file_service_name; + ServiceManager::ptr _mm_channels; +}; + +class MessageServer { + public: + using ptr = std::shared_ptr; + MessageServer(const MQClient::ptr &mq_client, + const Discovery::ptr service_discoverer, + const Registry::ptr ®_client, + const std::shared_ptr &es_client, + const std::shared_ptr &mysql_client, + const std::shared_ptr &server): + _mq_client(mq_client), + _service_discoverer(service_discoverer), + _registry_client(reg_client), + _es_client(es_client), + _mysql_client(mysql_client), + _rpc_server(server){} + ~MessageServer(){} + //搭建RPC服务器,并启动服务器 + void start() { + _rpc_server->RunUntilAskedToQuit(); + } + private: + Discovery::ptr _service_discoverer; + Registry::ptr _registry_client; + MQClient::ptr _mq_client; + std::shared_ptr _es_client; + std::shared_ptr _mysql_client; + std::shared_ptr _rpc_server; +}; + +class MessageServerBuilder { + public: + //构造es客户端对象 + void make_es_object(const std::vector host_list) { + _es_client = ESClientFactory::create(host_list); + } + //构造mysql客户端对象 + void make_mysql_object( + const std::string &user, + const std::string &pswd, + const std::string &host, + const std::string &db, + const std::string &cset, + int port, + int conn_pool_count) { + _mysql_client = ODBFactory::create(user, pswd, host, db, cset, port, conn_pool_count); + } + //用于构造服务发现客户端&信道管理对象 + void make_discovery_object(const std::string ®_host, + const std::string &base_service_name, + const std::string &file_service_name, + const std::string &user_service_name) { + _user_service_name = user_service_name; + _file_service_name = file_service_name; + _mm_channels = std::make_shared(); + _mm_channels->declared(file_service_name); + _mm_channels->declared(user_service_name); + LOG_DEBUG("设置文件子服务为需添加管理的子服务:{}", file_service_name); + auto put_cb = std::bind(&ServiceManager::onServiceOnline, _mm_channels.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&ServiceManager::onServiceOffline, _mm_channels.get(), std::placeholders::_1, std::placeholders::_2); + _service_discoverer = std::make_shared(reg_host, base_service_name, put_cb, del_cb); + } + //用于构造服务注册客户端对象 + void make_registry_object(const std::string ®_host, + const std::string &service_name, + const std::string &access_host) { + _registry_client = std::make_shared(reg_host); + _registry_client->registry(service_name, access_host); + } + //用于构造消息队列客户端对象 + void make_mq_object(const std::string &user, + const std::string &passwd, + const std::string &host, + const std::string &exchange_name, + const std::string &queue_name, + const std::string &binding_key) { + _exchange_name = exchange_name; + _queue_name = queue_name; + _mq_client = std::make_shared(user, passwd, host); + _mq_client->declareComponents(exchange_name, queue_name, binding_key); + } + void make_rpc_server(uint16_t port, int32_t timeout, uint8_t num_threads) { + if (!_es_client) { + LOG_ERROR("还未初始化ES搜索引擎模块!"); + abort(); + } + if (!_mysql_client) { + LOG_ERROR("还未初始化Mysql数据库模块!"); + abort(); + } + if (!_mm_channels) { + LOG_ERROR("还未初始化信道管理模块!"); + abort(); + } + _rpc_server = std::make_shared(); + + MessageServiceImpl *msg_service = new MessageServiceImpl(_es_client, + _mysql_client, _mm_channels, _file_service_name, _user_service_name); + int ret = _rpc_server->AddService(msg_service, + brpc::ServiceOwnership::SERVER_OWNS_SERVICE); + if (ret == -1) { + LOG_ERROR("添加Rpc服务失败!"); + abort(); + } + brpc::ServerOptions options; + options.idle_timeout_sec = timeout; + options.num_threads = num_threads; + ret = _rpc_server->Start(port, &options); + if (ret == -1) { + LOG_ERROR("服务启动失败!"); + abort(); + } + + auto callback = std::bind(&MessageServiceImpl::onMessage, msg_service, + std::placeholders::_1, std::placeholders::_2); + _mq_client->consume(_queue_name, callback); + } + //构造RPC服务器对象 + MessageServer::ptr build() { + if (!_service_discoverer) { + LOG_ERROR("还未初始化服务发现模块!"); + abort(); + } + if (!_registry_client) { + LOG_ERROR("还未初始化服务注册模块!"); + abort(); + } + if (!_rpc_server) { + LOG_ERROR("还未初始化RPC服务器模块!"); + abort(); + } + + MessageServer::ptr server = std::make_shared( + _mq_client, _service_discoverer, _registry_client, + _es_client, _mysql_client, _rpc_server); + return server; + } + private: + Registry::ptr _registry_client; + + std::shared_ptr _es_client; + std::shared_ptr _mysql_client; + + std::string _user_service_name; + std::string _file_service_name; + ServiceManager::ptr _mm_channels; + Discovery::ptr _service_discoverer; + + std::string _exchange_name; + std::string _queue_name; + MQClient::ptr _mq_client; + + std::shared_ptr _rpc_server; +}; +} \ No newline at end of file diff --git a/message/test/es_test/main.cc b/message/test/es_test/main.cc new file mode 100644 index 0000000..c0da2cb --- /dev/null +++ b/message/test/es_test/main.cc @@ -0,0 +1,34 @@ +#include "../../../common/data_es.hpp" +#include + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + + +DEFINE_string(es_host, "http://127.0.0.1:9200/", "es服务器URL"); + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + auto es_client = bite_im::ESClientFactory::create({FLAGS_es_host}); + + auto es_msg = std::make_shared(es_client); + // es_msg->createIndex(); + // es_msg->appendData("用户ID1", "消息ID1", 1723025035, "会话ID1", "吃饭了吗?"); + // es_msg->appendData("用户ID2", "消息ID2", 1723025035 - 100, "会话ID1", "吃的盖浇饭!"); + // es_msg->appendData("用户ID3", "消息ID3", 1723025035, "会话ID2", "吃饭了吗?"); + // es_msg->appendData("用户ID4", "消息ID4", 1723025035 - 100, "会话ID2", "吃的盖浇饭!"); + auto res = es_msg->search("盖浇", "会话ID1"); + for (auto &u : res) { + std::cout << "-----------------" << std::endl; + std::cout << u.user_id() << std::endl; + std::cout << u.message_id() << std::endl; + std::cout << u.session_id() << std::endl; + std::cout << boost::posix_time::to_simple_string(u.create_time()) << std::endl; + std::cout << u.content() << std::endl; + } + return 0; +} \ No newline at end of file diff --git a/message/test/es_test/请求格式规范.txt b/message/test/es_test/请求格式规范.txt new file mode 100644 index 0000000..e6bf9ed --- /dev/null +++ b/message/test/es_test/请求格式规范.txt @@ -0,0 +1,72 @@ +POST /message/_doc +{ + "settings" : { + "analysis" : { + "analyzer" : { + "ik" : { + "tokenizer" : "ik_max_word" + } + } + } + }, + "mappings" : { + "dynamic" : true, + "properties" : { + "chat_session_id" : { + "type" : "keyword", + "analyzer" : "standard" + }, + "message_id" : { + "type" : "keyword", + "analyzer" : "standard" + }, + "content" : { + "type" : "text", + "analyzer" : "ik_max_word" + } + } + } +} + + +GET /message/_doc/_search?pretty +{ + "query": { + "match_all": {} + } +} + + +POST /message/_doc/_bulk +{"index":{"_id":"1"}} +{"chat_session_id" : "会话ID1","message_id" : "消息ID1","content" : "吃饭了么?"} +{"index":{"_id":"2"}} +{"chat_session_id" : "会话ID1","message_id" : "消息ID2","content" : "吃的盖浇饭。"} +{"index":{"_id":"3"}} +{"chat_session_id" : "会话ID2","message_id" : "消息ID3","content" : "昨天吃饭了么?"} +{"index":{"_id":"4"}} +{"chat_session_id" : "会话ID2","message_id" : "消息ID4","content" : "昨天吃的盖浇饭。"} + + +GET /message/_doc/_search?pretty +{ + "query": { + "bool": { + "must": [ + { + "term": { + "chat_session_id.keyword": "会话ID1" + } + }, + { + "match": { + "content": "盖浇饭" + } + } + ] + } + } +} + + +DELETE /message diff --git a/message/test/message_client.cc b/message/test/message_client.cc new file mode 100644 index 0000000..e832809 --- /dev/null +++ b/message/test/message_client.cc @@ -0,0 +1,170 @@ +#include "etcd.hpp" +#include "channel.hpp" +#include "utils.hpp" +#include +#include +#include +#include +#include "message.pb.h" +#include "base.pb.h" +#include "user.pb.h" + + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(etcd_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(message_service, "/service/message_service", "服务监控根目录"); + +bite_im::ServiceManager::ptr sm; + +void range_test(const std::string &ssid, + const boost::posix_time::ptime &stime, + const boost::posix_time::ptime &etime) { + auto channel = sm->choose(FLAGS_message_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::MsgStorageService_Stub stub(channel.get()); + bite_im::GetHistoryMsgReq req; + bite_im::GetHistoryMsgRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_chat_session_id(ssid); + req.set_start_time(boost::posix_time::to_time_t(stime)); + req.set_over_time(boost::posix_time::to_time_t(etime)); + brpc::Controller cntl; + stub.GetHistoryMsg(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + + std::cout << rsp.msg_list_size() << std::endl; + + for (int i = 0; i < rsp.msg_list_size(); i++) { + std::cout << "-----------------------获取时间区间消息--------------------------\n"; + auto msg = rsp.msg_list(i); + std::cout << msg.message_id() << std::endl; + std::cout << msg.chat_session_id() << std::endl; + std::cout << boost::posix_time::to_simple_string(boost::posix_time::from_time_t(msg.timestamp())) << std::endl; + std::cout << msg.sender().user_id() << std::endl; + std::cout << msg.sender().nickname() << std::endl; + std::cout << msg.sender().avatar() << std::endl; + if (msg.message().message_type() == bite_im::MessageType::STRING) { + std::cout << "文本消息:" << msg.message().string_message().content() << std::endl; + }else if (msg.message().message_type() == bite_im::MessageType::IMAGE) { + std::cout << "图片消息:" << msg.message().image_message().image_content() << std::endl; + }else if (msg.message().message_type() == bite_im::MessageType::FILE) { + std::cout << "文件消息:" << msg.message().file_message().file_contents() << std::endl; + std::cout << "文件名称:" << msg.message().file_message().file_name() << std::endl; + }else if (msg.message().message_type() == bite_im::MessageType::SPEECH) { + std::cout << "语音消息:" << msg.message().speech_message().file_contents() << std::endl; + }else { + std::cout << "类型错误!!\n"; + } + } +} + +void recent_test(const std::string &ssid, int count) { + auto channel = sm->choose(FLAGS_message_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::MsgStorageService_Stub stub(channel.get()); + bite_im::GetRecentMsgReq req; + bite_im::GetRecentMsgRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_chat_session_id(ssid); + req.set_msg_count(count); + brpc::Controller cntl; + stub.GetRecentMsg(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + for (int i = 0; i < rsp.msg_list_size(); i++) { + std::cout << "----------------------获取最近消息---------------------------\n"; + auto msg = rsp.msg_list(i); + std::cout << msg.message_id() << std::endl; + std::cout << msg.chat_session_id() << std::endl; + std::cout << boost::posix_time::to_simple_string(boost::posix_time::from_time_t(msg.timestamp())) << std::endl; + std::cout << msg.sender().user_id() << std::endl; + std::cout << msg.sender().nickname() << std::endl; + std::cout << msg.sender().avatar() << std::endl; + if (msg.message().message_type() == bite_im::MessageType::STRING) { + std::cout << "文本消息:" << msg.message().string_message().content() << std::endl; + }else if (msg.message().message_type() == bite_im::MessageType::IMAGE) { + std::cout << "图片消息:" << msg.message().image_message().image_content() << std::endl; + }else if (msg.message().message_type() == bite_im::MessageType::FILE) { + std::cout << "文件消息:" << msg.message().file_message().file_contents() << std::endl; + std::cout << "文件名称:" << msg.message().file_message().file_name() << std::endl; + }else if (msg.message().message_type() == bite_im::MessageType::SPEECH) { + std::cout << "语音消息:" << msg.message().speech_message().file_contents() << std::endl; + }else { + std::cout << "类型错误!!\n"; + } + } +} + + +void search_test(const std::string &ssid, const std::string &key) { + auto channel = sm->choose(FLAGS_message_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::MsgStorageService_Stub stub(channel.get()); + bite_im::MsgSearchReq req; + bite_im::MsgSearchRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_chat_session_id(ssid); + req.set_search_key(key); + brpc::Controller cntl; + stub.MsgSearch(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + for (int i = 0; i < rsp.msg_list_size(); i++) { + std::cout << "----------------------关键字搜索消息---------------------------\n"; + auto msg = rsp.msg_list(i); + std::cout << msg.message_id() << std::endl; + std::cout << msg.chat_session_id() << std::endl; + std::cout << boost::posix_time::to_simple_string(boost::posix_time::from_time_t(msg.timestamp())) << std::endl; + std::cout << msg.sender().user_id() << std::endl; + std::cout << msg.sender().nickname() << std::endl; + std::cout << msg.sender().avatar() << std::endl; + if (msg.message().message_type() == bite_im::MessageType::STRING) { + std::cout << "文本消息:" << msg.message().string_message().content() << std::endl; + }else if (msg.message().message_type() == bite_im::MessageType::IMAGE) { + std::cout << "图片消息:" << msg.message().image_message().image_content() << std::endl; + }else if (msg.message().message_type() == bite_im::MessageType::FILE) { + std::cout << "文件消息:" << msg.message().file_message().file_contents() << std::endl; + std::cout << "文件名称:" << msg.message().file_message().file_name() << std::endl; + }else if (msg.message().message_type() == bite_im::MessageType::SPEECH) { + std::cout << "语音消息:" << msg.message().speech_message().file_contents() << std::endl; + }else { + std::cout << "类型错误!!\n"; + } + } +} + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + + //1. 先构造Rpc信道管理对象 + sm = std::make_shared(); + sm->declared(FLAGS_message_service); + auto put_cb = std::bind(&bite_im::ServiceManager::onServiceOnline, sm.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&bite_im::ServiceManager::onServiceOffline, sm.get(), std::placeholders::_1, std::placeholders::_2); + //2. 构造服务发现对象 + bite_im::Discovery::ptr dclient = std::make_shared(FLAGS_etcd_host, FLAGS_base_service, put_cb, del_cb); + + boost::posix_time::ptime stime(boost::posix_time::time_from_string("2000-08-02 00:00:00")); + boost::posix_time::ptime etime(boost::posix_time::time_from_string("2050-08-09 00:00:00")); + range_test("会话ID1", stime, etime); + recent_test("会话ID1", 2); + search_test("会话ID1", "盖浇"); + return 0; +} \ No newline at end of file diff --git a/message/test/mysql_test/main.cc b/message/test/mysql_test/main.cc new file mode 100644 index 0000000..b51e170 --- /dev/null +++ b/message/test/mysql_test/main.cc @@ -0,0 +1,63 @@ +#include "mysql_message.hpp" +#include + + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +void insert_test(bite_im::MessageTable &tb) { + bite_im::Message m1("消息ID1", "会话ID1", "用户ID1", 0, boost::posix_time::time_from_string("2002-01-20 23:59:59.000")); + tb.insert(m1); + bite_im::Message m2("消息ID2", "会话ID1", "用户ID2", 0, boost::posix_time::time_from_string("2002-01-21 23:59:59.000")); + tb.insert(m2); + bite_im::Message m3("消息ID3", "会话ID1", "用户ID3", 0, boost::posix_time::time_from_string("2002-01-22 23:59:59.000")); + tb.insert(m3); + + //另一个会话 + bite_im::Message m4("消息ID4", "会话ID2", "用户ID4", 0, boost::posix_time::time_from_string("2002-01-20 23:59:59.000")); + tb.insert(m4); + bite_im::Message m5("消息ID5", "会话ID2", "用户ID5", 0, boost::posix_time::time_from_string("2002-01-21 23:59:59.000")); + tb.insert(m5); +} +void remove_test(bite_im::MessageTable &tb) { + tb.remove("会话ID2"); +} + +void recent_test(bite_im::MessageTable &tb) { + auto res = tb.recent("会话ID1", 2); + auto begin = res.rbegin(); + auto end = res.rend(); + for (; begin != end; ++begin) { + std::cout << begin->message_id() << std::endl; + std::cout << begin->session_id() << std::endl; + std::cout << begin->user_id() << std::endl; + std::cout << boost::posix_time::to_simple_string(begin->create_time()) << std::endl; + } +} + +void range_test(bite_im::MessageTable &tb) { + boost::posix_time::ptime stime(boost::posix_time::time_from_string("2002-01-20 23:59:59.000")); + boost::posix_time::ptime etime(boost::posix_time::time_from_string("2002-01-21 23:59:59.000")); + auto res = tb.range("会话ID1", stime, etime); + for (const auto &m : res) { + std::cout << m.message_id() << std::endl; + std::cout << m.session_id() << std::endl; + std::cout << m.user_id() << std::endl; + std::cout << boost::posix_time::to_simple_string(m.create_time()) << std::endl; + } +} + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + auto db = bite_im::ODBFactory::create("root", "123456", "127.0.0.1", "bite_im", "utf8", 0, 1); + bite_im::MessageTable tb(db); + // insert_test(tb); + // remove_test(tb); + // recent_test(tb); + // range_test(tb); + return 0; +} \ No newline at end of file diff --git a/odb/chat_session.hxx b/odb/chat_session.hxx new file mode 100644 index 0000000..2e3879b --- /dev/null +++ b/odb/chat_session.hxx @@ -0,0 +1,68 @@ +#pragma once +#include +#include +#include +#include +#include "chat_session_member.hxx" + +namespace bite_im { + +enum class ChatSessionType { + SINGLE = 1, + GROUP = 2 +}; + +#pragma db object table("chat_session") +class ChatSession { + public: + ChatSession(){} + ChatSession(const std::string &ssid, + const std::string &ssname, const ChatSessionType sstype): + _chat_session_id(ssid), + _chat_session_name(ssname), + _chat_session_type(sstype){} + + std::string chat_session_id() const { return _chat_session_id; } + void chat_session_id(std::string &ssid) { _chat_session_id = ssid; } + + std::string chat_session_name() const { return _chat_session_name; } + void chat_session_name(std::string &ssname) { _chat_session_name = ssname; } + + ChatSessionType chat_session_type() const { return _chat_session_type; } + void chat_session_type(ChatSessionType val) { _chat_session_type = val; } + private: + friend class odb::access; + #pragma db id auto + unsigned long _id; + #pragma db type("varchar(64)") index unique + std::string _chat_session_id; + #pragma db type("varchar(64)") + std::string _chat_session_name; + #pragma db type("tinyint") + ChatSessionType _chat_session_type; //1-单聊; 2-群聊 +}; + +// 这里条件必须是指定条件: css::chat_session_type==1 && csm1.user_id=uid && csm2.user_id != csm1.user_id +#pragma db view object(ChatSession = css)\ + object(ChatSessionMember = csm1 : css::_chat_session_id == csm1::_session_id)\ + object(ChatSessionMember = csm2 : css::_chat_session_id == csm2::_session_id)\ + query((?)) +struct SingleChatSession { + #pragma db column(css::_chat_session_id) + std::string chat_session_id; + #pragma db column(csm2::_user_id) + std::string friend_id; +}; + +// 这里条件必须是指定条件: css::chat_session_type==2 && csm.user_id=uid +#pragma db view object(ChatSession = css)\ + object(ChatSessionMember = csm : css::_chat_session_id == csm::_session_id)\ + query((?)) +struct GroupChatSession { + #pragma db column(css::_chat_session_id) + std::string chat_session_id; + #pragma db column(css::_chat_session_name) + std::string chat_session_name; +}; + +} \ No newline at end of file diff --git a/odb/chat_session_member.hxx b/odb/chat_session_member.hxx new file mode 100644 index 0000000..db801ca --- /dev/null +++ b/odb/chat_session_member.hxx @@ -0,0 +1,32 @@ +#pragma once +#include +#include +#include + +namespace bite_im { +#pragma db object table("chat_session_member") + class ChatSessionMember { + public: + ChatSessionMember(){} + ChatSessionMember(const std::string& ssid, const std::string &uid) + :_session_id(ssid), _user_id(uid) {} + ~ChatSessionMember(){} + + std::string session_id() const {return _session_id; } + void session_id(std::string &ssid) { _session_id = ssid; } + + std::string user_id() const { return _user_id; } + void user_id(std::string &uid) {_user_id = uid; } + + private: + friend class odb::access; + #pragma db id auto + unsigned long _id; + #pragma db type("varchar(64)") index + std::string _session_id; + #pragma db type("varchar(64)") + std::string _user_id; + }; +} + +//odb -d mysql --generate-query --generate-schema --profile boost/date-time person.hxx \ No newline at end of file diff --git a/odb/friend_apply.hxx b/odb/friend_apply.hxx new file mode 100644 index 0000000..be9425d --- /dev/null +++ b/odb/friend_apply.hxx @@ -0,0 +1,34 @@ +#pragma once +#include +#include +#include + +namespace bite_im { +#pragma db object table("friend_apply") +class FriendApply{ + public: + FriendApply() {} + FriendApply(const std::string &eid, + const std::string &uid, const std::string &pid): + _user_id(uid), _peer_id(pid), _event_id(eid){} + + std::string event_id() const { return _event_id; } + void event_id(std::string &eid) { _event_id = eid; } + + std::string user_id() const { return _user_id; } + void user_id(std::string &uid) { _user_id = uid; } + + std::string peer_id() const { return _peer_id; } + void peer_id(std::string &uid) { _peer_id = uid; } + private: + friend class odb::access; + #pragma db id auto + unsigned long _id; + #pragma db type("varchar(64)") index unique + std::string _event_id; + #pragma db type("varchar(64)") index + std::string _user_id; + #pragma db type("varchar(64)") index + std::string _peer_id; +}; +} \ No newline at end of file diff --git a/odb/message.hxx b/odb/message.hxx new file mode 100644 index 0000000..90def46 --- /dev/null +++ b/odb/message.hxx @@ -0,0 +1,83 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace bite_im { +#pragma db object table("message") +class Message { + public: + Message(){} + Message(const std::string &mid, + const std::string &ssid, + const std::string &uid, + const unsigned char mtype, + const boost::posix_time::ptime &ctime): + _message_id(mid), _session_id(ssid), + _user_id(uid), _message_type(mtype), + _create_time(ctime){} + + std::string message_id() const { return _message_id; } + void message_id(const std::string &val) { _message_id = val; } + + std::string session_id() const { return _session_id; } + void session_id(const std::string &val) { _session_id = val; } + + std::string user_id() const { return _user_id; } + void user_id(const std::string &val) { _user_id = val; } + + unsigned char message_type() const { return _message_type; } + void message_type(unsigned char val) { _message_type = val; } + + boost::posix_time::ptime create_time() const { return _create_time; } + void create_time(const boost::posix_time::ptime &val) { _create_time = val; } + + std::string content() const { + if (!_content) return std::string(); + return *_content; + } + void content(const std::string &val) { _content = val; } + + std::string file_id() const { + if (!_file_id) return std::string(); + return *_file_id; + } + void file_id(const std::string &val) { _file_id = val; } + + std::string file_name() const { + if (!_file_name) return std::string(); + return *_file_name; + } + void file_name(const std::string &val) { _file_name = val; } + + unsigned int file_size() const { + if (!_file_size) return 0; + return *_file_size; + } + void file_size(unsigned int val) { _file_size = val; } + private: + friend class odb::access; + #pragma db id auto + unsigned long _id; + #pragma db type("varchar(64)") index unique + std::string _message_id; + #pragma db type("varchar(64)") index + std::string _session_id; //所属会话ID + #pragma db type("varchar(64)") + std::string _user_id; //发送者用户ID + unsigned char _message_type; //消息类型 0-文本;1-图片;2-文件;3-语音 + #pragma db type("TIMESTAMP") + boost::posix_time::ptime _create_time; //消息的产生时间 + + //可空信息字段 + odb::nullable _content; //文本消息内容--非文本消息可以忽略 + #pragma db type("varchar(64)") + odb::nullable _file_id; //文件消息的文件ID -- 文本消息忽略 + #pragma db type("varchar(128)") + odb::nullable _file_name; //文件消息的文件名称 -- 只针对文件消息有效 + odb::nullable _file_size; //文件消息的文件大小 -- 只针对文件消息有效 +}; +//odb -d mysql --std c++11 --generate-query --generate-schema --profile boost/date-time message.hxx +} \ No newline at end of file diff --git a/odb/relation.hxx b/odb/relation.hxx new file mode 100644 index 0000000..e1dbc82 --- /dev/null +++ b/odb/relation.hxx @@ -0,0 +1,30 @@ +#pragma once +#include +#include +#include +#include + +namespace bite_im { +#pragma db object table("relation") +class Relation { + public: + Relation(){} + Relation(const std::string &uid, const std::string &pid): + _user_id(uid), _peer_id(pid){} + + std::string user_id() const { return _user_id; } + void user_id(std::string &uid) { _user_id = uid; } + + std::string peer_id() const { return _peer_id; } + void peer_id(std::string &uid) { _peer_id = uid; } + private: + friend class odb::access; + #pragma db id auto + unsigned long _id; + #pragma db type("varchar(64)") index + std::string _user_id; + #pragma db type("varchar(64)") + std::string _peer_id; +}; +//odb -d mysql --std c++11 --generate-query --generate-schema --profile boost/date-time person.hxx +} \ No newline at end of file diff --git a/odb/user.hxx b/odb/user.hxx new file mode 100644 index 0000000..671a546 --- /dev/null +++ b/odb/user.hxx @@ -0,0 +1,71 @@ +#pragma once + +#include +#include +#include +#include +#include + +typedef boost::posix_time::ptime ptime; +#pragma db object table("user") +class User +{ +public: + User(){} + //用户名--新增用户 -- 用户ID, 昵称,密码 + User(const std::string &uid, const std::string &nickname, const std::string &password): + _user_id(uid), _nickname(nickname), _password(password){} + //手机号--新增用户 -- 用户ID, 手机号, 随机昵称 + User(const std::string &uid, const std::string &phone): + _user_id(uid), _nickname(uid), _phone(phone){} + + void user_id(const std::string &val) { _user_id = val; } + std::string user_id() { return _user_id; } + + std::string nickname() { + if (_nickname) return *_nickname; + return std::string(); + } + void nickname(const std::string &val) { _nickname = val; } + + std::string description() { + if (!_description) return std::string(); + return *_description; + } + void description(const std::string &val) { _description = val; } + + std::string password() { + if (!_password) return std::string(); + return *_password; + } + void password(const std::string &val) { _password = val; } + + std::string phone() { + if (!_phone) return std::string(); + return *_phone; + } + void phone(const std::string &val) { _phone = val; } + + std::string avatar_id() { + if (!_avatar_id) return std::string(); + return *_avatar_id; + } + void avatar_id(const std::string &val) { _avatar_id = val; } + + +private: + friend class odb::access; + #pragma db id auto + unsigned long _id; + #pragma db index + std::string _user_id; + odb::nullable _nickname; //用户昵称,不一定存在 + #pragma db index + odb::nullable _description; //用户签名 不一定存在 + odb::nullable _password; //用户密码 不一定存在 + #pragma db index + odb::nullable _phone; //用户手机号 不一定存在 + odb::nullable _avatar_id; //用户头像文件id 不一定存在 +}; + +//odb -d mysql --std c++11 --generate-query --generate-schema --profile boost/date-time person.hxx \ No newline at end of file diff --git a/proto/base.proto b/proto/base.proto new file mode 100644 index 0000000..79f9b1d --- /dev/null +++ b/proto/base.proto @@ -0,0 +1,82 @@ +syntax = "proto3"; +package bite_im; +option cc_generic_services = true; + +//用户信息结构 +message UserInfo { + string user_id = 1;//用户ID + string nickname = 2;//昵称 + string description = 3;//个人签名/描述 + string phone = 4; //绑定手机号 + bytes avatar = 5;//头像照片,文件内容使用二进制 +} + +//聊天会话信息 +message ChatSessionInfo { + //群聊会话不需要设置,单聊会话设置为对方用户ID + optional string single_chat_friend_id = 1; + string chat_session_id = 2; //会话ID + string chat_session_name = 3;//会话名称git + //会话上一条消息,新建的会话没有最新消息 + optional MessageInfo prev_message = 4; + //会话头像 --群聊会话不需要,直接由前端固定渲染,单聊就是对方的头像 + optional bytes avatar = 5; +} + +//消息类型 +enum MessageType { + STRING = 0; + IMAGE = 1; + FILE = 2; + SPEECH = 3; +} +message StringMessageInfo { + string content = 1;//文字聊天内容 +} +message ImageMessageInfo { + //图片文件id,客户端发送的时候不用设置,由transmit服务器进行设置后交给storage的时候设置 + optional string file_id = 1; + //图片数据,在ES中存储消息的时候只要id不要文件数据, 服务端转发的时候需要原样转发 + optional bytes image_content = 2; +} +message FileMessageInfo { + optional string file_id = 1;//文件id,客户端发送的时候不用设置 + optional int64 file_size = 2;//文件大小 + optional string file_name = 3;//文件名称 + //文件数据,在ES中存储消息的时候只要id和元信息,不要文件数据, 服务端转发的时候也不需要填充 + optional bytes file_contents = 4; +} +message SpeechMessageInfo { + //语音文件id,客户端发送的时候不用设置 + optional string file_id = 1; + //文件数据,在ES中存储消息的时候只要id不要文件数据, 服务端转发的时候也不需要填充 + optional bytes file_contents = 2; +} +message MessageContent { + MessageType message_type = 1; //消息类型 + oneof msg_content { + StringMessageInfo string_message = 2;//文字消息 + FileMessageInfo file_message = 3;//文件消息 + SpeechMessageInfo speech_message = 4;//语音消息 + ImageMessageInfo image_message = 5;//图片消息 + }; +} +//消息结构 +message MessageInfo { + string message_id = 1;//消息ID + string chat_session_id = 2;//消息所属聊天会话ID + int64 timestamp = 3;//消息产生时间 + UserInfo sender = 4;//消息发送者信息 + MessageContent message = 5; +} + +message FileDownloadData { + string file_id = 1; + bytes file_content = 2; +} + +message FileUploadData { + string file_name = 1; //文件名称 + int64 file_size = 2; //文件大小 + bytes file_content = 3; //文件数据 +} \ No newline at end of file diff --git a/proto/file.proto b/proto/file.proto new file mode 100644 index 0000000..1a9a33d --- /dev/null +++ b/proto/file.proto @@ -0,0 +1,64 @@ +syntax = "proto3"; +package bite_im; +import "base.proto"; + +option cc_generic_services = true; + +message GetSingleFileReq { + string request_id = 1; + string file_id = 2; + optional string user_id = 3; + optional string session_id = 4; +} +message GetSingleFileRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + optional FileDownloadData file_data = 4; +} + +message GetMultiFileReq { + string request_id = 1; + optional string user_id = 2; + optional string session_id = 3; + repeated string file_id_list = 4; +} +message GetMultiFileRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + map file_data = 4;//文件ID与文件数据的映射map +} + +message PutSingleFileReq { + string request_id = 1; //请求ID,作为处理流程唯一标识 + optional string user_id = 2; + optional string session_id = 3; + FileUploadData file_data = 4; +} +message PutSingleFileRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + FileMessageInfo file_info = 4; //返回了文件组织的元信息 +} + +message PutMultiFileReq { + string request_id = 1; + optional string user_id = 2; + optional string session_id = 3; + repeated FileUploadData file_data = 4; +} +message PutMultiFileRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + repeated FileMessageInfo file_info = 4; +} + +service FileService { + rpc GetSingleFile(GetSingleFileReq) returns (GetSingleFileRsp); + rpc GetMultiFile(GetMultiFileReq) returns (GetMultiFileRsp); + rpc PutSingleFile(PutSingleFileReq) returns (PutSingleFileRsp); + rpc PutMultiFile(PutMultiFileReq) returns (PutMultiFileRsp); +} \ No newline at end of file diff --git a/proto/friend.proto b/proto/friend.proto new file mode 100644 index 0000000..edd3309 --- /dev/null +++ b/proto/friend.proto @@ -0,0 +1,155 @@ +syntax = "proto3"; +package bite_im; +import "base.proto"; + +option cc_generic_services = true; + +//-------------------------------------- +//好友列表获取 +message GetFriendListReq { + string request_id = 1; // 请求标识ID + optional string user_id = 2; // 当前请求的发起者用户ID + optional string session_id = 3; //登录会话ID--用于网关进行身份识别--其他子服务用不到 +} +message GetFriendListRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + repeated UserInfo friend_list = 4; //要返回的用户信息 +} + +//-------------------------------------- +//好友删除 +message FriendRemoveReq { + string request_id = 1; + optional string user_id = 2; //当前用户ID + optional string session_id = 3; + string peer_id = 4; //要删除的好友ID +} +message FriendRemoveRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; +} +//-------------------------------------- +//添加好友--发送好友申请 +message FriendAddReq { + string request_id = 1; + optional string session_id = 2; + optional string user_id = 3;//申请人id + string respondent_id = 4;//被申请人id +} +message FriendAddRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + string notify_event_id = 4;//通知事件id +} +//-------------------------------------- +//好友申请的处理 +message FriendAddProcessReq { + string request_id = 1; + string notify_event_id = 2;//通知事件id + bool agree = 3;//是否同意好友申请 + string apply_user_id = 4; //申请人的用户id + optional string session_id = 5; + optional string user_id = 6; // 被申请人 +} +// +++++++++++++++++++++++++++++++++ +message FriendAddProcessRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + // 同意后会创建会话,向网关返回会话信息,用于通知双方会话的建立,这个字段客户端不需要关注 + optional string new_session_id = 4; +} +//-------------------------------------- +//获取待处理的,申请自己好友的信息列表 +message GetPendingFriendEventListReq { + string request_id = 1; + optional string session_id = 2; + optional string user_id = 3; +} + +message FriendEvent { + optional string event_id = 1; + UserInfo sender = 3; +} +message GetPendingFriendEventListRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + repeated FriendEvent event = 4; +} + +//-------------------------------------- +//好友搜索 +message FriendSearchReq { + string request_id = 1; + string search_key = 2;//就是名称模糊匹配关键字 + optional string session_id = 3; + optional string user_id = 4; +} +message FriendSearchRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + repeated UserInfo user_info = 4; +} + +//-------------------------------------- +//会话列表获取 +message GetChatSessionListReq { + string request_id = 1; + optional string session_id = 2; + optional string user_id = 3; +} +message GetChatSessionListRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + repeated ChatSessionInfo chat_session_info_list = 4; +} +//-------------------------------------- +//创建会话 +message ChatSessionCreateReq { + string request_id = 1; + optional string session_id = 2; + optional string user_id = 3; + string chat_session_name = 4; + //需要注意的是,这个列表中也必须包含创建者自己的用户ID + repeated string member_id_list = 5; +} +message ChatSessionCreateRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + //这个字段属于后台之间的数据,给前端回复的时候不需要这个字段,会话信息通过通知进行发送 + optional ChatSessionInfo chat_session_info = 4; +} +//-------------------------------------- +//获取会话成员列表 +message GetChatSessionMemberReq { + string request_id = 1; + optional string session_id = 2; + optional string user_id = 3; + string chat_session_id = 4; +} +message GetChatSessionMemberRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + repeated UserInfo member_info_list = 4; +} + +service FriendService { + rpc GetFriendList(GetFriendListReq) returns (GetFriendListRsp); + rpc FriendRemove(FriendRemoveReq) returns (FriendRemoveRsp); + rpc FriendAdd(FriendAddReq) returns (FriendAddRsp); + rpc FriendAddProcess(FriendAddProcessReq) returns (FriendAddProcessRsp); + rpc FriendSearch(FriendSearchReq) returns (FriendSearchRsp); + rpc GetChatSessionList(GetChatSessionListReq) returns (GetChatSessionListRsp); + rpc ChatSessionCreate(ChatSessionCreateReq) returns (ChatSessionCreateRsp); + rpc GetChatSessionMember(GetChatSessionMemberReq) returns (GetChatSessionMemberRsp); + rpc GetPendingFriendEventList(GetPendingFriendEventListReq) returns (GetPendingFriendEventListRsp); +} \ No newline at end of file diff --git a/proto/gateway.proto b/proto/gateway.proto new file mode 100644 index 0000000..4fa312f --- /dev/null +++ b/proto/gateway.proto @@ -0,0 +1,53 @@ +syntax = "proto3"; +package bite_im; +option cc_generic_services = true; + +message ClientAuthenticationReq { + string request_id = 1; + string session_id = 2; // 用于向服务器表明当前长连接客户端的身份 +} + +//在客户端与网关服务器的通信中,使用HTTP协议进行通信 +// 通信时采用POST请求作为请求方法 +// 通信时,正文采用protobuf作为正文协议格式,具体内容字段以前边各个文件中定义的字段格式为准 +/* 以下是HTTP请求的功能与接口路径对应关系: + SERVICE HTTP PATH: + { + 获取随机验证码 /service/user/get_random_verify_code + 获取短信验证码 /service/user/get_phone_verify_code + 用户名密码注册 /service/user/username_register + 用户名密码登录 /service/user/username_login + 手机号码注册 /service/user/phone_register + 手机号码登录 /service/user/phone_login + 获取个人信息 /service/user/get_user_info + 修改头像 /service/user/set_avatar + 修改昵称 /service/user/set_nickname + 修改签名 /service/user/set_description + 修改绑定手机 /service/user/set_phone + + 获取好友列表 /service/friend/get_friend_list + 获取好友信息 /service/friend/get_friend_info + 发送好友申请 /service/friend/add_friend_apply + 好友申请处理 /service/friend/add_friend_process + 删除好友 /service/friend/remove_friend + 搜索用户 /service/friend/search_friend + 获取指定用户的消息会话列表 /service/friend/get_chat_session_list + 创建消息会话 /service/friend/create_chat_session + 获取消息会话成员列表 /service/friend/get_chat_session_member + 获取待处理好友申请事件列表 /service/friend/get_pending_friend_events + + 获取历史消息/离线消息列表 /service/message_storage/get_history + 获取最近N条消息列表 /service/message_storage/get_recent + 搜索历史消息 /service/message_storage/search_history + + 发送消息 /service/message_transmit/new_message + + 获取单个文件数据 /service/file/get_single_file + 获取多个文件数据 /service/file/get_multi_file + 发送单个文件 /service/file/put_single_file + 发送多个文件 /service/file/put_multi_file + + 语音转文字 /service/speech/recognition + } + +*/ \ No newline at end of file diff --git a/proto/message.proto b/proto/message.proto new file mode 100644 index 0000000..4d13a7d --- /dev/null +++ b/proto/message.proto @@ -0,0 +1,55 @@ +syntax = "proto3"; +package bite_im; +import "base.proto"; + +option cc_generic_services = true; + +message GetHistoryMsgReq { + string request_id = 1; + string chat_session_id = 2; + int64 start_time = 3; + int64 over_time = 4; + optional string user_id = 5; + optional string session_id = 6; +} +message GetHistoryMsgRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + repeated MessageInfo msg_list = 4; +} + +message GetRecentMsgReq { + string request_id = 1; + string chat_session_id = 2; + int64 msg_count = 3; + optional int64 cur_time = 4;//用于扩展获取指定时间前的n条消息 + optional string user_id = 5; + optional string session_id = 6; +} +message GetRecentMsgRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + repeated MessageInfo msg_list = 4; +} + +message MsgSearchReq { + string request_id = 1; + optional string user_id = 2; + optional string session_id = 3; + string chat_session_id = 4; + string search_key = 5; +} +message MsgSearchRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + repeated MessageInfo msg_list = 4; +} + +service MsgStorageService { + rpc GetHistoryMsg(GetHistoryMsgReq) returns (GetHistoryMsgRsp); + rpc GetRecentMsg(GetRecentMsgReq) returns (GetRecentMsgRsp); + rpc MsgSearch(MsgSearchReq) returns (MsgSearchRsp); +} \ No newline at end of file diff --git a/proto/notify.proto b/proto/notify.proto new file mode 100644 index 0000000..eb88843 --- /dev/null +++ b/proto/notify.proto @@ -0,0 +1,39 @@ +syntax = "proto3"; +package bite_im; +import "base.proto"; +option cc_generic_services = true; + +enum NotifyType { + FRIEND_ADD_APPLY_NOTIFY = 0; + FRIEND_ADD_PROCESS_NOTIFY = 1; + CHAT_SESSION_CREATE_NOTIFY = 2; + CHAT_MESSAGE_NOTIFY = 3; + FRIEND_REMOVE_NOTIFY = 4; +} +message NotifyFriendAddApply { + UserInfo user_info = 1; //申请人信息 +} +message NotifyFriendAddProcess { + bool agree = 1; + UserInfo user_info = 2; //处理人信息 +} +message NotifyFriendRemove { + string user_id = 1; //删除自己的用户 ID +} +message NotifyNewChatSession { + ChatSessionInfo chat_session_info = 1; //新建会话信息 +} +message NotifyNewMessage { + MessageInfo message_info = 1; //新消息 +} +message NotifyMessage { + optional string notify_event_id = 1;//通知事件操作 id(有则填无则忽略) + NotifyType notify_type = 2;//通知事件类型 + oneof notify_remarks { //事件备注信息 + NotifyFriendAddApply friend_add_apply = 3; + NotifyFriendAddProcess friend_process_result = 4; + NotifyFriendRemove friend_remove = 7; + NotifyNewChatSession new_chat_session_info = 5;//会话信息 + NotifyNewMessage new_message_info = 6;//消息信息 + } +} \ No newline at end of file diff --git a/proto/speech.proto b/proto/speech.proto new file mode 100644 index 0000000..df26bbd --- /dev/null +++ b/proto/speech.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; +package bite_im; + +option cc_generic_services = true; + +message SpeechRecognitionReq { + string request_id = 1; //请求ID + bytes speech_content = 2; //语音数据 + optional string user_id = 3; //用户ID + optional string session_id = 4; //登录会话ID -- 网关进行身份鉴权 +} + +message SpeechRecognitionRsp { + string request_id = 1; //请求ID + bool success = 2; //请求处理结果标志 + optional string errmsg = 3; //失败原因 + optional string recognition_result = 4; //识别后的文字数据 +} + +//语音识别Rpc服务及接口的定义 +service SpeechService { + rpc SpeechRecognition(SpeechRecognitionReq) returns (SpeechRecognitionRsp); +} \ No newline at end of file diff --git a/proto/transmite.proto b/proto/transmite.proto new file mode 100644 index 0000000..2546d3b --- /dev/null +++ b/proto/transmite.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; +package bite_im; +import "base.proto"; + +option cc_generic_services = true; + +//这个用于和网关进行通信 +message NewMessageReq { + string request_id = 1; //请求ID -- 全链路唯一标识 + optional string user_id = 2; + optional string session_id = 3;//客户端身份识别信息 -- 这就是消息发送者 + string chat_session_id = 4; //聊天会话ID -- 标识了当前消息属于哪个会话,应该转发给谁 + MessageContent message = 5; // 消息内容--消息类型+内容 +} +message NewMessageRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; +} + +//这个用于内部的通信,生成完整的消息信息,并获取消息的转发人员列表 +message GetTransmitTargetRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + MessageInfo message = 4; // 组织好的消息结构 -- + repeated string target_id_list = 5; //消息的转发目标列表 +} + +service MsgTransmitService { + rpc GetTransmitTarget(NewMessageReq) returns (GetTransmitTargetRsp); +} \ No newline at end of file diff --git a/proto/user.proto b/proto/user.proto new file mode 100644 index 0000000..6a2548a --- /dev/null +++ b/proto/user.proto @@ -0,0 +1,166 @@ +syntax = "proto3"; +package bite_im; +import "base.proto"; +option cc_generic_services = true; + +//---------------------------- +//用户名注册 +message UserRegisterReq { + string request_id = 1; + string nickname = 2; + string password = 3; + optional string verify_code_id = 4; //目前客户端实现了本地验证,该字段没用了 + optional string verify_code = 5;//目前客户端实现了本地验证,该字段没用了 +} +message UserRegisterRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; +} +//---------------------------- +//用户名登录 +message UserLoginReq { + string request_id = 1; + string nickname = 2; + string password = 3; + optional string verify_code_id = 4; + optional string verify_code = 5; +} +message UserLoginRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + string login_session_id = 4; +} +//---------------------------- +//手机号验证码获取 +message PhoneVerifyCodeReq { + string request_id = 1; + string phone_number = 2; +} +message PhoneVerifyCodeRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + string verify_code_id = 4; +} +//---------------------------- +//手机号注册 +message PhoneRegisterReq { + string request_id = 1; + string phone_number = 2; + string verify_code_id = 3; + string verify_code = 4; +} +message PhoneRegisterRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; +} +//---------------------------- +//手机号登录 +message PhoneLoginReq { + string request_id = 1; + string phone_number = 2; + string verify_code_id = 3; + string verify_code = 4; +} +message PhoneLoginRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + string login_session_id = 4; +} +//个人信息获取-这个只用于获取当前登录用户的信息 +// 客户端传递的时候只需要填充session_id即可 +//其他个人/好友信息的获取在好友操作中完成 +message GetUserInfoReq { + string request_id = 1; + optional string user_id = 2; // 这个字段是网关进行身份鉴权之后填入的字段 + optional string session_id = 3; // 进行客户端身份识别的关键字段 +} +message GetUserInfoRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + UserInfo user_info = 4; +} +//内部接口 +message GetMultiUserInfoReq { + string request_id = 1; + repeated string users_id = 2; +} +message GetMultiUserInfoRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; + map users_info = 4; +} +//---------------------------- +//用户头像修改 +message SetUserAvatarReq { + string request_id = 1; + optional string user_id = 2; + optional string session_id = 3; + bytes avatar = 4; +} +message SetUserAvatarRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; +} +//---------------------------- +//用户昵称修改 +message SetUserNicknameReq { + string request_id = 1; + optional string user_id = 2; + optional string session_id = 3; + string nickname = 4; +} +message SetUserNicknameRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; +} +//---------------------------- +//用户签名修改 +message SetUserDescriptionReq { + string request_id = 1; + optional string user_id = 2; + optional string session_id = 3; + string description = 4; +} +message SetUserDescriptionRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; +} +//---------------------------- +//用户手机修改 +message SetUserPhoneNumberReq { + string request_id = 1; + optional string user_id = 2; + optional string session_id = 3; + string phone_number = 4; + string phone_verify_code_id = 5; + string phone_verify_code = 6; +} +message SetUserPhoneNumberRsp { + string request_id = 1; + bool success = 2; + string errmsg = 3; +} + +service UserService { + rpc UserRegister(UserRegisterReq) returns (UserRegisterRsp); + rpc UserLogin(UserLoginReq) returns (UserLoginRsp); + rpc GetPhoneVerifyCode(PhoneVerifyCodeReq) returns (PhoneVerifyCodeRsp); + rpc PhoneRegister(PhoneRegisterReq) returns (PhoneRegisterRsp); + rpc PhoneLogin(PhoneLoginReq) returns (PhoneLoginRsp); + rpc GetUserInfo(GetUserInfoReq) returns (GetUserInfoRsp); + rpc GetMultiUserInfo(GetMultiUserInfoReq) returns (GetMultiUserInfoRsp); + rpc SetUserAvatar(SetUserAvatarReq) returns (SetUserAvatarRsp); + rpc SetUserNickname(SetUserNicknameReq) returns (SetUserNicknameRsp); + rpc SetUserDescription(SetUserDescriptionReq) returns (SetUserDescriptionRsp); + rpc SetUserPhoneNumber(SetUserPhoneNumberReq) returns (SetUserPhoneNumberRsp); +} \ No newline at end of file diff --git a/speech/CMakeLists.txt b/speech/CMakeLists.txt new file mode 100644 index 0000000..27d5f56 --- /dev/null +++ b/speech/CMakeLists.txt @@ -0,0 +1,54 @@ +# 1. 添加cmake版本说明 +cmake_minimum_required(VERSION 3.1.3) +# 2. 声明工程名称 +project(speech_server) + +set(target "speech_server") +set(test_client "speech_client") + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的proto映射代码文件名称 +set(proto_path ${CMAKE_CURRENT_SOURCE_DIR}/../proto) +set(proto_files speech.proto) +# 2. 检测框架代码文件是否已经生成 +set(proto_hxx "") +set(proto_cxx "") +set(proto_srcs "") +foreach(proto_file ${proto_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".proto" ".pb.cc" proto_cc ${proto_file}) + string(REPLACE ".proto" ".pb.h" proto_hh ${proto_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${proto_cc}) + add_custom_command( + PRE_BUILD + COMMAND protoc + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} -I ${proto_path} --experimental_allow_proto3_optional ${proto_path}/${proto_file} + DEPENDS ${proto_path}/${proto_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + COMMENT "生成Protobuf框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + ) + endif() + list(APPEND proto_srcs ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc}) +endforeach() + +# 4. 获取源码目录下的所有源码文件 +set(src_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/source src_files) +# 5. 声明目标及依赖 +add_executable(${target} ${src_files} ${proto_srcs}) +# 7. 设置需要连接的库 +target_link_libraries(${target} -lgflags -lspdlog -lfmt -lbrpc -lssl -lcrypto -lprotobuf -lleveldb -letcd-cpp-api -lcpprest -lcurl /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19) + + +set(test_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/test test_files) +add_executable(${test_client} ${test_files} ${proto_srcs}) +target_link_libraries(${test_client} -lgflags -lspdlog -lfmt -lbrpc -lssl -lcrypto -lprotobuf -lleveldb -letcd-cpp-api -lcpprest -lcurl /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19) + +# 6. 设置头文件默认搜索路径 +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../common) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third/include) + +#8. 设置安装路径 +INSTALL(TARGETS ${target} ${test_client} RUNTIME DESTINATION bin) \ No newline at end of file diff --git a/speech/dockerfile b/speech/dockerfile new file mode 100644 index 0000000..dd9a1a2 --- /dev/null +++ b/speech/dockerfile @@ -0,0 +1,16 @@ +# 声明基础经镜像来源 +FROM debian:12 + +# 声明工作目录 +WORKDIR /im +RUN mkdir -p /im/logs &&\ + mkdir -p /im/data &&\ + mkdir -p /im/conf &&\ + mkdir -p /im/bin + +# 将可执行程序依赖,拷贝进镜像 +COPY ./build/speech_server /im/bin/ +# 将可执行程序文件,拷贝进镜像 +COPY ./depends /lib/x86_64-linux-gnu/ +# 设置容器启动的默认操作 ---运行程序 +CMD /im/bin/speech_server -flagfile=/im/conf/speech_server.conf \ No newline at end of file diff --git a/speech/source/speech_server.cc b/speech/source/speech_server.cc new file mode 100644 index 0000000..bc2ac05 --- /dev/null +++ b/speech/source/speech_server.cc @@ -0,0 +1,33 @@ +#include "speech_server.hpp" +//语音识别子服务 + +DEFINE_bool(run_mode, false, "程序的运行模式, false-调试;true-发布"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志的输出等级"); + +DEFINE_string(registry_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(instance_name, "/speech_service/instance", "当前实例名称"); +DEFINE_string(access_host, "127.0.0.1:10001", "当前实例的外部访问地址"); + +DEFINE_int32(listen_port, 10001, "Rpc服务器监听端口"); +DEFINE_int32(rpc_timeout, -1, "Rpc调用超时时间"); +DEFINE_int32(rpc_threads, 1, "Rpc的IO线程数量"); + +DEFINE_string(app_id, "118805148", "语音平台应用ID"); +DEFINE_string(api_key, "tRBBbRWdTOjHgr8xZX0s4Z2d", "语音平台API密钥"); +DEFINE_string(secret_key, "H2pyXuWi04uKEKK0T8jrTYo7Pj4UUUpC", "语音平台加密密钥"); + +int main(int argc, char* argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + bite_im::SpeechServerBuilder ssb; + ssb.make_asr_object(FLAGS_app_id, FLAGS_api_key, FLAGS_secret_key); + ssb.make_rpc_server(FLAGS_listen_port, FLAGS_rpc_timeout, FLAGS_rpc_threads); + ssb.make_reg_object(FLAGS_registry_host, FLAGS_base_service + FLAGS_instance_name, FLAGS_access_host); + auto server = ssb.build(); + server->start(); + return 0; +} \ No newline at end of file diff --git a/speech/source/speech_server.hpp b/speech/source/speech_server.hpp new file mode 100644 index 0000000..2bb35de --- /dev/null +++ b/speech/source/speech_server.hpp @@ -0,0 +1,123 @@ +#include +#include + +#include "asr.hpp" //语音识别模块 +#include "etcd.hpp" //服务注册模块 +#include "logger.hpp" //日志模块 +#include "speech.pb.h" //protobuf框架代码 + +using namespace bite_im; + +namespace bite_im{ +class SpeechServiceImpl : public bite_im::SpeechService { + public: + SpeechServiceImpl(const ASRClient::ptr &asr_client): + _asr_client(asr_client){} + ~SpeechServiceImpl(){} + void SpeechRecognition(google::protobuf::RpcController* controller, + const ::bite_im::SpeechRecognitionReq* request, + ::bite_im::SpeechRecognitionRsp* response, + ::google::protobuf::Closure* done) { + LOG_DEBUG("收到语音转文字请求!"); + brpc::ClosureGuard rpc_guard(done); + //1. 取出请求中的语音数据 + //2. 调用语音sdk模块进行语音识别,得到响应 + std::string err; + std::string res = _asr_client->recognize(request->speech_content(), err); + if (res.empty()) { + LOG_ERROR("{} 语音识别失败!", request->request_id()); + response->set_request_id(request->request_id()); + response->set_success(false); + response->set_errmsg("语音识别失败:" + err); + return; + } + //3. 组织响应 + response->set_request_id(request->request_id()); + response->set_success(true); + response->set_recognition_result(res); + } + private: + ASRClient::ptr _asr_client; +}; + +class SpeechServer { + public: + using ptr = std::shared_ptr; + SpeechServer(const ASRClient::ptr asr_client, + const Registry::ptr ®_client, + const std::shared_ptr &server): + _asr_client(asr_client), + _reg_client(reg_client), + _rpc_server(server){} + ~SpeechServer(){} + //搭建RPC服务器,并启动服务器 + void start() { + _rpc_server->RunUntilAskedToQuit(); + } + private: + ASRClient::ptr _asr_client; + Registry::ptr _reg_client; + std::shared_ptr _rpc_server; +}; + +class SpeechServerBuilder { + public: + //构造语音识别客户端对象 + void make_asr_object(const std::string &app_id, + const std::string &api_key, + const std::string &secret_key) { + _asr_client = std::make_shared(app_id, api_key, secret_key); + } + //用于构造服务注册客户端对象 + void make_reg_object(const std::string ®_host, + const std::string &service_name, + const std::string &access_host) { + _reg_client = std::make_shared(reg_host); + _reg_client->registry(service_name, access_host); + } + //构造RPC服务器对象 + void make_rpc_server(uint16_t port, int32_t timeout, uint8_t num_threads) { + if (!_asr_client) { + LOG_ERROR("还未初始化语音识别模块!"); + abort(); + } + _rpc_server = std::make_shared(); + SpeechServiceImpl *speech_service = new SpeechServiceImpl(_asr_client); + int ret = _rpc_server->AddService(speech_service, + brpc::ServiceOwnership::SERVER_OWNS_SERVICE); + if (ret == -1) { + LOG_ERROR("添加Rpc服务失败!"); + abort(); + } + brpc::ServerOptions options; + options.idle_timeout_sec = timeout; + options.num_threads = num_threads; + ret = _rpc_server->Start(port, &options); + if (ret == -1) { + LOG_ERROR("服务启动失败!"); + abort(); + } + } + SpeechServer::ptr build() { + if (!_asr_client) { + LOG_ERROR("还未初始化语音识别模块!"); + abort(); + } + if (!_reg_client) { + LOG_ERROR("还未初始化服务注册模块!"); + abort(); + } + if (!_rpc_server) { + LOG_ERROR("还未初始化RPC服务器模块!"); + abort(); + } + SpeechServer::ptr server = std::make_shared( + _asr_client, _reg_client, _rpc_server); + return server; + } + private: + ASRClient::ptr _asr_client; + Registry::ptr _reg_client; + std::shared_ptr _rpc_server; +}; +} \ No newline at end of file diff --git a/speech/speech_server.conf b/speech/speech_server.conf new file mode 100644 index 0000000..dd259c8 --- /dev/null +++ b/speech/speech_server.conf @@ -0,0 +1,13 @@ +-run_mode=true +-log_file=/im/logs/speech.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-instance_name=/speech_service/instance +-access_host=10.0.0.235:10001 +-base_service=/service +-listen_port=10001 +-rpc_timeout=-1 +-rpc_threads=1 +-app_id=60694095 +-api_key=PWn6zlsxym8VwpBW8Or4PPGe +-secret_key=Bl0mn74iyAkr3FzCo5TZV7lBq7NYoms9 \ No newline at end of file diff --git a/speech/test/speech_client.cc b/speech/test/speech_client.cc new file mode 100644 index 0000000..9f18109 --- /dev/null +++ b/speech/test/speech_client.cc @@ -0,0 +1,73 @@ + +//speech_server的测试客户端实现 +//进行服务的发现,发现speech_server服务器节点地址信息并实例化通信信道 +//读取语音文件数据 +//发起语音识别rpc调用 + +#include +#include "etcd.hpp" +#include "channel.hpp" +#include +#include +#include "aip-cpp-sdk/speech.h" +#include "aip-cpp-sdk/base/utils.h" +#include "speech.pb.h" + + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(etcd_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(speech_service, "/service/speech_service", "服务监控根目录"); + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + + //1. 先构造Rpc信道管理对象 + auto sm = std::make_shared(); + sm->declared(FLAGS_speech_service); + auto put_cb = std::bind(&bite_im::ServiceManager::onServiceOnline, sm.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&bite_im::ServiceManager::onServiceOffline, sm.get(), std::placeholders::_1, std::placeholders::_2); + //2. 构造服务发现对象 + bite_im::Discovery::ptr dclient = std::make_shared(FLAGS_etcd_host, FLAGS_base_service, put_cb, del_cb); + + //3. 通过Rpc信道管理对象,获取提供Echo服务的信道 + auto channel = sm->choose(FLAGS_speech_service); + if (!channel) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + return -1; + } + //读取语音文件数据 + std::string file_content; + aip::get_file_content("16k.pcm", &file_content); + std::cout << file_content.size() << std::endl; + + //4. 发起EchoRpc调用 + bite_im::SpeechService_Stub stub(channel.get()); + bite_im::SpeechRecognitionReq req; + req.set_speech_content(file_content); + req.set_request_id("111111"); + + brpc::Controller *cntl = new brpc::Controller(); + bite_im::SpeechRecognitionRsp *rsp = new bite_im::SpeechRecognitionRsp(); + stub.SpeechRecognition(cntl, &req, rsp, nullptr); + if (cntl->Failed() == true) { + std::cout << "Rpc调用失败:" << cntl->ErrorText() << std::endl; + delete cntl; + delete rsp; + std::this_thread::sleep_for(std::chrono::seconds(1)); + return -1; + } + if (rsp->success() == false) { + std::cout << rsp->errmsg() << std::endl; + return -1; + } + std::cout << "收到响应: " << rsp->request_id() << std::endl; + std::cout << "收到响应: " << rsp->recognition_result() << std::endl; + return 0; +} \ No newline at end of file diff --git a/third/include/aip-cpp-sdk/README.md b/third/include/aip-cpp-sdk/README.md new file mode 100644 index 0000000..3167564 --- /dev/null +++ b/third/include/aip-cpp-sdk/README.md @@ -0,0 +1,35 @@ +# 安装百度AI开放平台 C++ SDK + +**C++ SDK目录结构** + + ├── base + │ ├── base.h // 授权相关类 + │ ├── base64.h // base64加密类 + │ ├── http.h // http请求类 + │ └── utils.h // 工具类 + ├── face.h // 人脸识别交互类 + ├── image_censor.h // 图像审核交互类 + ├── image_classify.h // 图像识别交互类 + ├── image_search.h // 图像搜索交互类 + ├── kg.h // 人脸识别交互类 + ├── nlp.h // 人脸识别交互类 + ├── ocr.h // 人脸识别交互类 + └── speech.h // 语音识别交互类 + +**支持 C++ 11+** + +**直接使用开发包步骤如下:** + +1.在[官方网站](http://ai.baidu.com/sdk)下载C++ SDK压缩包。 + +2.将下载的`aip-cpp-sdk-version.zip`解压, 其中文件为包含实现代码的头文件。 + +3.安装依赖库curl(需要支持ssl) openssl jsoncpp(>1.6.2版本,0.x版本将不被支持)。 + +4.编译工程时添加 C++11 支持 (gcc/clang 添加编译参数 -std=c++11), 添加第三方库链接参数 lcurl, lcrypto, ljsoncpp。 + +5.在源码中include 您需要使用的交互类头文件(face.h image_censor.h image_classify.h kg.h nlp.h ocr.h speech.h等),引入压缩包中的头文即可使用aip命名空间下的类和方法。 + +# 详细使用文档 + +参考[百度AI开放平台官方文档](http://ai.baidu.com/docs) diff --git a/third/include/aip-cpp-sdk/base/base.h b/third/include/aip-cpp-sdk/base/base.h new file mode 100644 index 0000000..404e0e0 --- /dev/null +++ b/third/include/aip-cpp-sdk/base/base.h @@ -0,0 +1,311 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ +#ifndef __AIP_BASE_H__ +#define __AIP_BASE_H__ + +#include +#include +#include "http.h" +#include "json/json.h" +#include "base64.h" +#include "curl/curl.h" +#include "utils.h" + +namespace aip { + + static const char* AIP_SDK_VERSION = "0.3.3"; + static const char* CURL_ERROR_CODE = "curl_error_code"; + static const std::string ACCESS_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"; + static const std::map null; + static const Json::Value json_null; + + class AipBase + { + private: + std::string _app_id; + int _expired_time; + bool _is_bce; + bool _has_decide_type; + std::string _scope; + + protected: + std::string getAccessToken() + { + time_t now = time(NULL); + if (!access_token.empty()) + { + return this->access_token; + } + + if (now < this->_expired_time - 60 * 60 * 24) + { + return this->access_token; + } + + std::string response; + std::map params; + + params["grant_type"] = "client_credentials"; + params["client_id"] = this->ak; + params["client_secret"] = this->sk; + int status_code = this->client.get( + ACCESS_TOKEN_URL, + ¶ms, + nullptr, + &response + ); + Json::Value obj; + if (status_code != CURLcode::CURLE_OK) { + obj[CURL_ERROR_CODE] = status_code; + return obj.toStyledString(); + } + + std::string error; + std::unique_ptr reader(crbuilder.newCharReader()); + reader->parse(response.data(), response.data() + response.size(), &obj, &error); + this->access_token = obj["access_token"].asString(); + this->_expired_time = obj["expires_in"].asInt() + (int) now; + this->_scope = obj["scope"].asString(); + return this->access_token; + } + + void merge_json(Json::Value& data, const Json::Value& options) { + Json::Value::Members mem = options.getMemberNames(); + for (auto & iter : mem) { + data[iter.c_str()] = options[iter]; + } + } + + public: + std::string ak; + std::string sk; + HttpClient client; + Json::CharReaderBuilder crbuilder; + std::string access_token; + AipBase(const std::string & app_id, const std::string & ak, const std::string & sk): + _app_id(app_id), + _is_bce(false), + _has_decide_type(false), + ak(ak), + sk(sk) + { + if (_app_id == "") + { + } + } + + void setConnectionTimeoutInMillis(int connect_timeout) + { + this->client.setConnectTimeout(connect_timeout); + } + + void setSocketTimeoutInMillis(int socket_timeout) + { + this->client.setSocketTimeout(socket_timeout); + } + + void setDebug(bool debug) + { + this->client.setDebug(debug); + } + + std::string getAk() { + return ak; + } + + Json::Value request( + std::string url, + std::map const & params, + std::string const & data, + std::map const & headers) + { + std::string response; + Json::Value obj; + std::string body; + auto headers_for_sign = headers; + + auto temp_params = params; + + temp_params["charset"] = "UTF-8"; + + this->prepare_request(url, temp_params, headers_for_sign); + + int status_code = this->client.post(url, &temp_params, data, &headers_for_sign, &response); + + if (status_code != CURLcode::CURLE_OK) { + obj[CURL_ERROR_CODE] = status_code; + return obj; + } + + std::string error; + std::unique_ptr reader(crbuilder.newCharReader()); + reader->parse(response.data(), response.data() + response.size(), &obj, &error); + + return obj; + } + + Json::Value request( + std::string url, + std::map const & params, + std::map const & data, + std::map const & headers) + { + std::string response; + Json::Value obj; + + auto headers_for_sign = headers; + auto temp_params = params; + + this->prepare_request(url, temp_params, headers_for_sign); + + int status_code = this->client.post(url, &temp_params, data, &headers_for_sign, &response); + + if (status_code != CURLcode::CURLE_OK) { + obj[CURL_ERROR_CODE] = status_code; + return obj; + } + + std::string error; + std::unique_ptr reader(crbuilder.newCharReader()); + reader->parse(response.data(), response.data() + response.size(), &obj, &error); + + return obj; + } + + void prepare_request(std::string url, + std::map & params, + std::map & headers) + { + + params["aipSdk"] = "C"; + params["aipSdkVersion"] = AIP_SDK_VERSION; + + if (_has_decide_type) { + if (_is_bce) { + std::string method = "POST"; + sign(method, url, params, headers, ak, sk); + } else { + params["access_token"] = this->getAccessToken(); + } + + return; + } + + if (getAccessToken() == "") { + _is_bce = true; + + } else { + + const char * t = std::strstr(this->_scope.c_str(), "brain_all_scope"); + + if (t == NULL) + { + _is_bce = true; + } + } + + _has_decide_type = true; + prepare_request(url, params, headers); + } + + + Json::Value requestjson( + std::string url, + Json::Value & data, + std::map & params, + std::map const & headers) + { + + std::string response; + Json::Value obj; + auto headers_for_sign = headers; + auto temp_params = params; + this->prepare_request(url, temp_params, headers_for_sign); + int status_code = this->client.post(url, nullptr, data, nullptr, &response); + if (status_code != CURLcode::CURLE_OK) { + obj[aip::CURL_ERROR_CODE] = status_code; + return obj; + } + + std::string error; + std::unique_ptr reader(crbuilder.newCharReader()); + reader->parse(response.data(), response.data() + response.size(), &obj, &error); + + return obj; + } + +// Json::Value request_com( +// std::string const & url, +// Json::Value & data) +// { +// std::string response; +// Json::Value obj; +// int status_code = this->client.post(url, nullptr, data, nullptr, &response); +// +// if (status_code != CURLcode::CURLE_OK) { +// obj[aip::CURL_ERROR_CODE] = status_code; +// return obj; +// } +// std::string error; +// std::unique_ptr reader(crbuilder.newCharReader()); +// reader->parse(response.data(), response.data() + response.size(), &obj, &error); +// +// return obj; +// } + + Json::Value request_com( + std::string const & url, + Json::Value & data, + std::map* headers = nullptr, + std::map* params = nullptr) + { + std::string response; + Json::Value obj; + + std::map headers_for_sign; + if (headers != nullptr) { + headers_for_sign = *headers; + } + + std::map temp_params; + if (params != nullptr) { + temp_params = *params; + } + this->prepare_request(url, temp_params, headers_for_sign); + + int status_code = CURLcode::CURLE_OK; + if (headers == nullptr || headers->find("Content-Type") == headers->end() + || (*headers)["Content-Type"] == "application/json") { + status_code = this->client.post(url, &temp_params, data, &headers_for_sign, &response); + } else if ((*headers)["Content-Type"] == "application/x-www-form-urlencoded") { + status_code = this->client.post_form(url, &temp_params, data, &headers_for_sign, &response); + } + + if (status_code != CURLcode::CURLE_OK) { + obj[aip::CURL_ERROR_CODE] = status_code; + return obj; + } + std::string error; + std::unique_ptr reader(crbuilder.newCharReader()); + reader->parse(response.data(), response.data() + response.size(), &obj, &error); + + return obj; + } + + + }; + +} +#endif diff --git a/third/include/aip-cpp-sdk/base/base64.h b/third/include/aip-cpp-sdk/base/base64.h new file mode 100644 index 0000000..e0ab6ff --- /dev/null +++ b/third/include/aip-cpp-sdk/base/base64.h @@ -0,0 +1,130 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ +#ifndef __AIP_BASE64_H__ +#define __AIP_BASE64_H__ + +#include +#include + +namespace aip { + + static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + + + static inline bool is_base64(const char c) + { + return (isalnum(c) || (c == '+') || (c == '/')); + } + + std::string base64_encode(const char * bytes_to_encode, unsigned int in_len) + { + std::string ret; + int i = 0; + int j = 0; + unsigned char char_array_3[3]; + unsigned char char_array_4[4]; + + while (in_len--) + { + char_array_3[i++] = *(bytes_to_encode++); + if(i == 3) + { + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + + for(i = 0; (i <4) ; i++) + { + ret += base64_chars[char_array_4[i]]; + } + i = 0; + } + } + + if(i) + { + for(j = i; j < 3; j++) + { + char_array_3[j] = '\0'; + } + + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + + for(j = 0; (j < i + 1); j++) + { + ret += base64_chars[char_array_4[j]]; + } + + while((i++ < 3)) + { + ret += '='; + } + + } + + return ret; + } + + std::string base64_decode(std::string const & encoded_string) + { + int in_len = (int) encoded_string.size(); + int i = 0; + int j = 0; + int in_ = 0; + unsigned char char_array_4[4], char_array_3[3]; + std::string ret; + + while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i ==4) { + for (i = 0; i <4; i++) + char_array_4[i] = base64_chars.find(char_array_4[i]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) + ret += char_array_3[i]; + i = 0; + } + } + + if (i) { + for (j = i; j <4; j++) + char_array_4[j] = 0; + + for (j = 0; j <4; j++) + char_array_4[j] = base64_chars.find(char_array_4[j]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; (j < i - 1); j++) ret += char_array_3[j]; + } + + return ret; + } + +} +#endif diff --git a/third/include/aip-cpp-sdk/base/http.h b/third/include/aip-cpp-sdk/base/http.h new file mode 100644 index 0000000..a4a4ab1 --- /dev/null +++ b/third/include/aip-cpp-sdk/base/http.h @@ -0,0 +1,306 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ +#ifndef __AIP_HTTP_H__ +#define __AIP_HTTP_H__ + +#include "curl/curl.h" + +#include +#include +#include +#include + +namespace aip { + + inline size_t onWriteData(void * buffer, size_t size, size_t nmemb, void * userp) + { + std::string * str = dynamic_cast((std::string *)userp); + str->append((char *)buffer, size * nmemb); + return nmemb; + } + + class HttpClient + { + private: + bool debug = false; + int connect_timeout = 10000; + int socket_timeout = 10000; + + void makeUrlencodedForm(std::map const & params, std::string * content) const + { + content->clear(); + std::map::const_iterator it; + for(it=params.begin(); it!=params.end(); it++) + { + char * key = curl_escape(it->first.c_str(), (int) it->first.size()); + char * value = curl_escape(it->second.c_str(),(int) it->second.size()); + *content += key; + *content += '='; + *content += value; + *content += '&'; + curl_free(key); + curl_free(value); + } + } + + void appendUrlParams(std::map const & params, std::string* url) const + { + if(params.empty()) { + return; + } + std::string content; + this->makeUrlencodedForm(params, &content); + bool url_has_param = false; + for (const auto& ch : *url) { + if (ch == '?') { + url_has_param = true; + break; + } + } + if (url_has_param) { + url->append("&"); + } else { + url->append("?"); + } + url->append(content); + } + + void appendHeaders(std::map const & headers, curl_slist ** slist) const + { + std::ostringstream ostr; + std::map::const_iterator it; + for(it=headers.begin(); it!=headers.end(); it++) + { + ostr << it->first << ":" << it->second; + *slist = curl_slist_append(*slist, ostr.str().c_str()); + ostr.str(""); + } + } + + public: + HttpClient() = default; + + HttpClient(const HttpClient &) = delete; + HttpClient & operator=(const HttpClient &) = delete; + + void setConnectTimeout(int connect_timeout) + { + this->connect_timeout = connect_timeout; + } + + void setSocketTimeout(int socket_timeout) + { + this->socket_timeout = socket_timeout; + } + + void setDebug(bool debug) + { + this->debug = debug; + } + + int get( + std::string url, + std::map const * params, + std::map const * headers, + std::string * response) const + { + CURL * curl = curl_easy_init(); + struct curl_slist * slist = NULL; + if (headers) { + this->appendHeaders(*headers, &slist); + } + if (params) { + this->appendUrlParams(*params, &url); + } + + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, slist); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, onWriteData); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void *) response); + curl_easy_setopt(curl, CURLOPT_NOSIGNAL, true); + curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT_MS, this->connect_timeout); + curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, this->socket_timeout); + curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, false); + curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, false); + curl_easy_setopt(curl, CURLOPT_VERBOSE, this->debug); + + int status_code = curl_easy_perform(curl); + + curl_easy_cleanup(curl); + curl_slist_free_all(slist); + + return status_code; + } + + int post( + std::string url, + std::map const * params, + const std::string & body, + std::map const * headers, + std::string * response) const + { + struct curl_slist * slist = NULL; + CURL * curl = curl_easy_init(); + if (headers) { + this->appendHeaders(*headers, &slist); + } + if (params) { + this->appendUrlParams(*params, &url); + } + + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, slist); + curl_easy_setopt(curl, CURLOPT_POST, true); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.size()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, onWriteData); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void *) response); + curl_easy_setopt(curl, CURLOPT_NOSIGNAL, true); + curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT_MS, this->connect_timeout); + curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, this->socket_timeout); + curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, false); + curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, false); + curl_easy_setopt(curl, CURLOPT_VERBOSE, this->debug); + + int status_code = curl_easy_perform(curl); + + curl_easy_cleanup(curl); + curl_slist_free_all(slist); + + return status_code; + } + + /** + * application/x-www-form-urlencoded + * @param url + * @param params + * @param data + * @param headers + * @param response + * @return + */ + int post( + std::string url, + std::map const * params, + std::map const & data, + std::map const * headers, + std::string * response) const + { + std::string body; + this->makeUrlencodedForm(data, &body); + return this->post(std::move(url), params, body, headers, response); + } + + /** + * application/json + * @param url + * @param params + * @param data + * @param headers + * @param response + * @return + */ + int post( + std::string url, + std::map const * params, + Json::Value const & data, + std::map const * headers, + std::string * response) const + { + std::string body; + Json::StreamWriterBuilder swb; + std::unique_ptr writer(swb.newStreamWriter()); + std::ostringstream os; + writer->write(data, &os); + body = os.str(); + std::map temp_headers; + if (headers != nullptr) { + for (const auto & iter : *headers) { + temp_headers[iter.first] = iter.second; + } + } + if (temp_headers.find("Content-Type") == temp_headers.end()) { + temp_headers["Content-Type"] = "application/json"; + } + return this->post(url.c_str(), params, body, &temp_headers, response); + } + + /** + * application/x-www-form-urlencoded + * all type data + * @param url + * @param params + * @param data + * @param headers + * @param response + * @return + */ + int post_form( + std::string url, + std::map const * params, + Json::Value const & data, + std::map const * headers, + std::string * response) const + { + std::string body; + body.clear(); + Json::Value::Members mem = data.getMemberNames(); + for (auto iter = mem.begin(); iter != mem.end(); iter++) { + std::string str = ""; + char * curl_escape_value; + char * key = curl_escape((*iter).c_str(), (int)((*iter).size())); + body += key; + body += '='; + Json::Value jsonValue = data[*iter]; + switch(jsonValue.type()) { + case Json::realValue: + body += std::to_string(data[*iter].asDouble()); + break; + case Json::intValue: + body += std::to_string(data[*iter].asInt64()); + break; + case Json::booleanValue: + body += std::to_string(data[*iter].asBool()); + break; + case Json::stringValue: + str = data[*iter].asString(); + curl_escape_value = curl_escape(str.c_str(), (int)(str.size())); + body += curl_escape_value; + curl_free(curl_escape_value); + break; + default: + break; + } + body += '&'; + curl_free(key); + } + return this->post(std::move(url), params, body, headers, response); + } + + + int post( + std::string url, + std::map const * params, + std::map const * headers, + std::string * response) const + { + const static std::string EMPTY_STRING; + return this->post(std::move(url), params, EMPTY_STRING, headers, response); + } + }; + +} + +#endif diff --git a/third/include/aip-cpp-sdk/base/utils.h b/third/include/aip-cpp-sdk/base/utils.h new file mode 100644 index 0000000..35e9f18 --- /dev/null +++ b/third/include/aip-cpp-sdk/base/utils.h @@ -0,0 +1,283 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ +#ifndef __AIP_UTILS_H__ +#define __AIP_UTILS_H__ + +#include +#include +#include +#include +#include +#include +#include + +const int __BCE_VERSION__ = 1; +const int __BCE_EXPIRE__ = 1800; + +namespace aip { + + template + std::basic_istream& getall(std::basic_istream& input, + std::basic_string& str) { + std::ostringstream oss; + oss << input.rdbuf(); + str.assign(oss.str()); + return input; + } + + inline int get_file_content(const char *filename, std::string* out) { + std::ifstream in(filename, std::ios::in | std::ios::binary); + if (in) { + getall(in, *out); + return 0; + } else { + return -1; + } + } + + inline std::string to_upper(std::string src) + { + std::transform(src.begin(), src.end(), src.begin(), [](unsigned char c) { return std::toupper(c); }); + return src; + } + + + inline std::string to_lower(std::string src) + { + std::transform(src.begin(), src.end(), src.begin(), [](unsigned char c) { return std::tolower(c); }); + return src; + } + + inline std::string to_hex(unsigned char c, bool lower = false) + { + const std::string hex = "0123456789ABCDEF"; + + std::stringstream ss; + ss << hex[c >> 4] << hex[c & 0xf]; + + return lower ? to_lower(ss.str()) : ss.str(); + } + + inline time_t now() + { + return time(NULL); + } + + std::string utc_time(time_t timestamp) + { + struct tm result_tm; + char buffer[32]; + +#ifdef _WIN32 + gmtime_s(&result_tm, ×tamp); +#else + gmtime_r(×tamp, &result_tm); +#endif + + size_t size = strftime(buffer, 32, "%Y-%m-%dT%H:%M:%SZ", &result_tm); + + return std::string(buffer, size); + } + + void url_parse( + const std::string & url, + std::map & params) + { + int pos = (int)url.find("?"); + if (pos != -1) + { + int key_start = pos + 1, + key_len = 0, + val_start = 0; + for (int i = key_start; i <= (int)url.size(); ++i) + { + switch (url[i]) + { + case '=': + key_len = i - key_start; + val_start = i + 1; + break; + case '\0': + case '&': + if (key_len != 0) + { + params[url.substr(key_start, key_len)] = url.substr(val_start, i - val_start); + key_start = i + 1; + key_len = 0; + } + break; + default: + break; + } + } + } + } + + std::string url_encode(const std::string & input, bool encode_slash=true) + { + std::stringstream ss; + const char *str = input.c_str(); + + for (uint32_t i = 0; i < input.size(); i++) + { + unsigned char c = str[i]; + if (isalnum(c) || c == '_' || c == '-' || c == '~' || c == '.' || (!encode_slash && c == '/')) + { + ss << c; + } + else + { + ss << "%" << to_hex(c); + } + } + + return ss.str(); + } + + std::string canonicalize_params(std::map & params) + { + std::vector v; + v.reserve(params.size()); + + for (auto & it : params) { + v.push_back(url_encode(it.first) + "=" + url_encode(it.second)); + } + std::sort(v.begin(), v.end()); + + std::string result; + for (auto & it : v) + { + result.append((result.empty() ? "" : "&") + it); + } + return result; + } + + std::string canonicalize_headers(std::map & headers) + { + std::vector v; + v.reserve(headers.size()); + + for (auto & it : headers) { + v.push_back(url_encode(to_lower(it.first)) + ":" + url_encode(it.second)); + } + std::sort(v.begin(), v.end()); + + std::string result; + for (auto & it : v) + { + result.append((result.empty() ? "" : "\n") + it); + } + return result; + } + + std::string get_headers_keys(std::map & headers) + { + std::vector v; + v.reserve(headers.size()); + + for (auto & it : headers) { + v.push_back(to_lower(it.first)); + } + + std::string result; + for (auto & it : v) + { + result.append((result.empty() ? "" : ";") + it); + } + return result; + } + + std::string get_host(const std::string & url) + { + int pos = (int)url.find("://") + 3; + return url.substr( + pos, + url.find('/', pos) - pos + ); + } + + + std::string get_path(const std::string & url) + { + int path_start = (int)url.find('/', url.find("://") + 3); + int path_end = (int)url.find('?'); + path_end = path_end == -1 ? (int)url.size() : path_end; + + return url.substr(path_start, path_end - path_start); + } + + std::string hmac_sha256( + const std::string & src, + const std::string & sk) + { + const EVP_MD *evp_md = EVP_sha256(); + unsigned char md[EVP_MAX_MD_SIZE]; + unsigned int md_len = 0; + + if (HMAC(evp_md, + reinterpret_cast(sk.data()), (int)sk.size(), + reinterpret_cast(src.data()), src.size(), + md, &md_len) == NULL) + { + return ""; + } + + std::stringstream ss; + for (int i = 0; i < (int)md_len; ++i) + { + ss << to_hex(md[i], true); + } + + return ss.str(); + } + + void sign( + std::string method, + std::string & url, + std::map & params, + std::map & headers, + std::string & ak, + std::string & sk) + { + url_parse(url, params); + headers["Host"] = get_host(url); + std::string timestamp = utc_time(now()); + headers["x-bce-date"] = timestamp; + + std::stringstream ss; + ss << "bce-auth-v" << __BCE_VERSION__ << "/" << ak << "/" + << timestamp << "/" << __BCE_EXPIRE__; + + std::string val = ss.str(); + std::string sign_key = hmac_sha256(val, sk); + + ss.str(""); + ss << to_upper(method) << '\n' << url_encode(get_path(url), false) + << '\n' << canonicalize_params(params) + << '\n' << canonicalize_headers(headers); + + std::string signature = hmac_sha256(ss.str(), sign_key); + + ss.str(""); + ss << "bce-auth-v" << __BCE_VERSION__ << "/" << ak << "/" + << timestamp << "/" << __BCE_EXPIRE__ << "/" + << get_headers_keys(headers) << "/" << signature; + + headers["authorization"] = ss.str(); + } + +} + +#endif diff --git a/third/include/aip-cpp-sdk/body_analysis.h b/third/include/aip-cpp-sdk/body_analysis.h new file mode 100644 index 0000000..af3b2ee --- /dev/null +++ b/third/include/aip-cpp-sdk/body_analysis.h @@ -0,0 +1,220 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_BODY_ANALYSIS_H__ +#define __AIP_BODY_ANALYSIS_H__ + +#include "base/base.h" + +namespace aip { + + class Bodyanalysis : public AipBase + { + public: + + std::string _body_analysis_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/body_analysis"; + std::string _body_attr_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/body_attr"; + std::string _body_num_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/body_num"; + std::string _driver_behavior_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/driver_behavior"; + std::string _body_seg_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/body_seg"; + std::string _gesture_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/gesture"; + std::string _body_tracking_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/body_tracking"; + std::string _hand_analysis_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/hand_analysis"; + std::string _body_danger_v1 = + "https://aip.baidubce.com/rest/2.0/video-classify/v1/body_danger"; + std::string _fingertip_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/fingertip"; + + Bodyanalysis(const std::string & app_id, const std::string & ak, const std::string & sk) + : AipBase(app_id, ak, sk) + { + } + + + /** + * 人体关键点识别 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/BODY/0k3cpyxme + */ + Json::Value body_analysis_v1( + std::string const &image) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_body_analysis_v1, null, data, null); + + return result; + } + + /** + * 人体检测与属性识别 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/BODY/Ak3cpyx6v + */ + Json::Value body_attr_v1( + std::string const &image, + const std::map &options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_body_attr_v1, null, data, null); + + return result; + } + + /** + * 人流量统计 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/BODY/7k3cpyy1t + */ + Json::Value body_num_v1( + std::string const &image, + const std::map &options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_body_num_v1, null, data, null); + + return result; + } + + /** + * 驾驶行为分析 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/BODY/Nk3cpywct + */ + Json::Value driver_behavior_v1( + std::string const &image, + const std::map &options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_driver_behavior_v1, null, data, null); + + return result; + } + + /** + * 人像分割 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/BODY/Fk3cpyxua + */ + Json::Value body_seg_v1( + std::string const &image, + const std::map &options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_body_seg_v1, null, data, null); + + return result; + } + + /** + * 手势识别 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/BODY/4k3cpywrv + */ + Json::Value gesture_v1( + std::string const &image) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_gesture_v1, null, data, null); + + return result; + } + + /** + * 人流量统计(动态版) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/BODY/wk3cpyyog + */ + Json::Value body_tracking_v1( + std::string const &dynamic, + std::string const &image, + Json::Value & options) + { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["dynamic"] = dynamic; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_body_tracking_v1, data, &headers); + + return result; + } + + /** + * 手部关键点识别 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/BODY/Kk3cpyxeu + */ + Json::Value hand_analysis_v1( + std::string const &image) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_hand_analysis_v1, null, data, null); + + return result; + } + + /** + * 危险行为识别 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/BODY/uk3cpywke + */ + Json::Value body_danger_v1( + std::string const &video_data) + { + std::map data; + data["data"] = base64_encode(video_data.c_str(), (int) video_data.size()); + Json::Value result = + this->request(_body_danger_v1, null, data, null); + + return result; + } + + /** + * 指尖检测 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/BODY/Jk7ir38ut + */ + Json::Value fingertip_v1( + std::string const &image) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_fingertip_v1, null, data, null); + + return result; + } + + }; +} +#endif \ No newline at end of file diff --git a/third/include/aip-cpp-sdk/content_censor.h b/third/include/aip-cpp-sdk/content_censor.h new file mode 100644 index 0000000..c15cade --- /dev/null +++ b/third/include/aip-cpp-sdk/content_censor.h @@ -0,0 +1,327 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_CONTENTCENSOR_H__ +#define __AIP_CONTENTCENSOR_H__ + +#include "base/base.h" + +namespace aip { + +class Contentcensor: public AipBase +{ +public: + std::string _img_censor_user_defined_v2 = + "https://aip.baidubce.com/rest/2.0/solution/v1/img_censor/v2/user_defined"; + std::string _text_censor_user_defined_v2 = + "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined"; + std::string _live_save_v1 = "https://aip.baidubce.com/rest/2.0/solution/v1/live/v1/config/save"; + std::string _live_stop_v1 = "https://aip.baidubce.com/rest/2.0/solution/v1/live/v1/config/stop"; + std::string _live_view_v1 = "https://aip.baidubce.com/rest/2.0/solution/v1/live/v1/config/view"; + std::string _live_pull_v1 = "https://aip.baidubce.com/rest/2.0/solution/v1/live/v1/audit/pull"; + std::string _video_censor_submit_v1 = "https://aip.baidubce.com/rest/2.0/solution/v1/video_censor/v1/video/submit"; + std::string _video_censor_pull_v1 = "https://aip.baidubce.com/rest/2.0/solution/v1/video_censor/v1/video/pull"; + std::string _async_voice_submit_v1 = "https://aip.baidubce.com/rest/2.0/solution/v1/async_voice/submit"; + std::string _async_voice_pull_v1 = "https://aip.baidubce.com/rest/2.0/solution/v1/async_voice/pull"; + std::string _document_censor_submit_url = "https://aip.baidubce.com/rest/2.0/solution/v1/solution/document/v1/submit"; + std::string _document_censor_pull_url = "https://aip.baidubce.com/rest/2.0/solution/v1/solution/document/v1/pull"; + + Contentcensor(const std::string & app_id, const std::string & ak, const std::string & sk) + : AipBase(app_id, ak, sk) + { + } + + /** + * 内容审核平台-图像 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/jk42xep4e + */ + Json::Value img_censor_user_defined_v2_img(std::string const &image, const Json::Value & options) + { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_img_censor_user_defined_v2, data, &headers); + + return result; + } + + /** + * 内容审核平台-图像 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/jk42xep4e + */ + Json::Value img_censor_user_defined_v2_url(std::string const &imgUrl, const Json::Value & options) + { + Json::Value data; + data["imgUrl"] = imgUrl; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_img_censor_user_defined_v2, data, &headers); + + return result; + } + + /** + * 内容审核平台-文本 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/Rk3h6xb3i + */ + Json::Value text_censor_user_defined_v2(std::string const &text) + { + std::map data; + data["text"] = text; + Json::Value result = + this->request(_text_censor_user_defined_v2, null, data, null); + return result; + } + + /** + * 内容审核平台-直播流(新增任务) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/mkxlraoz5 + */ + Json::Value live_save_v1(std::string const &streamUrl, std::string const &streamType, + std::string const &extId, long long const &startTime, + long long const &endTime, std::string const &streamName, const Json::Value & options) + { + Json::Value data; + data["streamUrl"] = streamUrl; + data["streamType"] = streamType; + data["extId"] = extId; + data["startTime"] = startTime; + data["endTime"] = endTime; + data["streamName"] = streamName; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_live_save_v1, data, &headers); + + return result; + } + + /** + * 内容审核平台-直播流(删除任务) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/Ckxls2owb + */ + Json::Value live_stop_v1( + std::string const &taskId, + const Json::Value & options) + { + Json::Value data; + data["taskId"] = taskId; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_live_stop_v1, data, &headers); + + return result; + } + + /** + * 内容审核平台-直播流(查看配置) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/ckxls6tl1 + */ + Json::Value live_view_v1( + std::string const &taskId, + const Json::Value & options) + { + Json::Value data; + data["taskId"] = taskId; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_live_view_v1, data, &headers); + + return result; + } + + /** + * 内容审核平台-直播流(获取结果) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/Pkxlshd1s + */ + Json::Value live_pull_v1( + std::string const &taskId, + const Json::Value & options) + { + Json::Value data; + data["taskId"] = taskId; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_live_view_v1, data, &headers); + + return result; + } + + + /** + * 内容审核平台-长视频(提交任务) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/bksy7ak30 + */ + Json::Value video_censor_submit_v1( + std::string const &url, + std::string const &extId, + const Json::Value & options) + { + Json::Value data; + data["url"] = url; + data["extId"] = extId; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_video_censor_submit_v1, data, &headers); + + return result; + } + + /** + * 内容审核平台-长视频(获取结果) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/jksy7j3jv + */ + Json::Value video_censor_pull_v1( + std::string const &taskId, + const Json::Value & options) + { + Json::Value data; + data["taskId"] = taskId; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_video_censor_pull_v1, data, &headers); + + return result; + } + + /** + * 音频文件异步审核 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/akxlple3t + */ + Json::Value async_voice_submit_v1( + std::string const &url, std::string const &fmt, int rate, + const Json::Value & options) + { + Json::Value data; + data["url"] = url; + data["fmt"] = fmt; + data["rate"] = rate; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_async_voice_submit_v1, data, &headers); + + return result; + } + + /** + * 音频文件异步审核-查询 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/jkxlpxllo + */ + Json::Value async_voice_pull_v1_taskid( + std::string const &taskId) + { + std::map data; + data["taskId"] = taskId; + Json::Value result = + this->request(_async_voice_pull_v1, null, data, null); + return result; + } + + /** + * 音频文件异步审核-查询 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/ANTIPORN/jkxlpxllo + */ + Json::Value async_voice_pull_v1_audioid( + std::string const &audioId) + { + std::map data; + data["audioId"] = audioId; + Json::Value result = + this->request(_async_voice_pull_v1, null, data, null); + return result; + } + + /** + * 文档审核-提交任务 + * https://ai.baidu.com/ai-doc/ANTIPORN/2l8484xvl + */ + Json::Value document_censor_file_submit( + std::string const & file_name, + std::string const & document, + const std::map & options) + { + std::map data; + + data["fileBase64"] = base64_encode(document.c_str(), (int) document.size()); + data["fileName"] = file_name; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_document_censor_submit_url, null, data, null); + + return result; + } + + /** + * 文档审核-提交任务 + * https://ai.baidu.com/ai-doc/ANTIPORN/2l8484xvl + */ + Json::Value document_censor_url_submit( + std::string const & file_name, + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + data["fileName"] = file_name; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_document_censor_submit_url, null, data, null); + + return result; + } + + /** + * 文档审核-拉取结果 + * https://ai.baidu.com/ai-doc/ANTIPORN/4l848df5n + */ + Json::Value document_censor_pull( + std::string const & task_id, + const std::map & options) + { + std::map data; + + data["taskId"] = task_id; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_document_censor_pull_url, null, data, null); + + return result; + } + +}; +} +#endif diff --git a/third/include/aip-cpp-sdk/face.h b/third/include/aip-cpp-sdk/face.h new file mode 100644 index 0000000..a46c0a3 --- /dev/null +++ b/third/include/aip-cpp-sdk/face.h @@ -0,0 +1,1074 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_FACE_H__ +#define __AIP_FACE_H__ + +#include "base/base.h" + +namespace aip { + + class Face: public AipBase + { + public: + + std::string _faceverify = + "https://aip.baidubce.com/rest/2.0/face/v4/faceverify"; + + std::string _detect = + "https://aip.baidubce.com/rest/2.0/face/v2/detect"; + + std::string _match = + "https://aip.baidubce.com/rest/2.0/face/v2/match"; + + std::string _identify = + "https://aip.baidubce.com/rest/2.0/face/v2/identify"; + + std::string _verify = + "https://aip.baidubce.com/rest/2.0/face/v2/verify"; + + std::string _user_add = + "https://aip.baidubce.com/rest/2.0/face/v2/faceset/user/add"; + + std::string _user_update = + "https://aip.baidubce.com/rest/2.0/face/v2/faceset/user/update"; + + std::string _user_delete = + "https://aip.baidubce.com/rest/2.0/face/v2/faceset/user/delete"; + + std::string _user_get = + "https://aip.baidubce.com/rest/2.0/face/v2/faceset/user/get"; + + std::string _group_getlist = + "https://aip.baidubce.com/rest/2.0/face/v2/faceset/group/getlist"; + + std::string _group_getusers = + "https://aip.baidubce.com/rest/2.0/face/v2/faceset/group/getusers"; + + std::string _group_adduser = + "https://aip.baidubce.com/rest/2.0/face/v2/faceset/group/adduser"; + + std::string _group_deleteuser = + "https://aip.baidubce.com/rest/2.0/face/v2/faceset/group/deleteuser"; + + std::string _face_verify_v4 = + "https://aip.baidubce.com/rest/2.0/face/v4/mingjing/verify"; + + std::string _face_match_v4 = + "https://aip.baidubce.com/rest/2.0/face/v4/mingjing/match"; + + std::string _online_picture_live_v4 = "https://aip.baidubce.com/rest/2.0/face/v4/faceverify"; + + std::string _faceliveness_sessioncode_v1 = + "https://aip.baidubce.com/rest/2.0/face/v1/faceliveness/sessioncode"; + std::string _faceliveness_verify_v1 = + "https://aip.baidubce.com/rest/2.0/face/v1/faceliveness/verify"; + std::string _face_detect_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/detect"; + std::string _face_match_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/match"; + std::string _face_search_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/search"; + std::string _face_faceset_user_add_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/user/add"; + std::string _face_faceset_user_update_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/user/update"; + std::string _face_faceset_user_delete_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/user/delete"; + std::string _face_faceset_user_get_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/user/get"; + std::string _face_faceset_group_getlist_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/group/getlist"; + std::string _face_faceset_group_getusers_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/group/getusers"; + std::string _face_faceset_user_copy_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/user/copy"; + std::string _face_fasetset_face_getlist_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/face/getlist"; + std::string _face_faceset_group_add_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/group/add"; + std::string _face_faceset_group_delete_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/group/delete"; + std::string _face_faceset_face_delete_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceset/face/delete"; + std::string _face_faceverify_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/faceverify"; + std::string _face_person_idmatch_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/person/idmatch"; + std::string _face_multi_search_v3 = + "https://aip.baidubce.com/rest/2.0/face/v3/multi-search"; + std::string _face_merge_v1 = + "https://aip.baidubce.com/rest/2.0/face/v1/merge"; + std::string _face_skin_smooth_v1 = + "https://aip.baidubce.com/rest/2.0/face/v1/editattr"; + std::string _face_landmark_v1 = + "https://aip.baidubce.com/rest/2.0/face/v1/landmark"; + std::string _face_scene_faceset_user_add = + "https://aip.baidubce.com/rest/2.0/face/scene/faceset/user/add"; + std::string _face_scene_faceset_user_update = + "https://aip.baidubce.com/rest/2.0/face/scene/faceset/user/update"; + std::string _face_scene_faceset_group_add = + "https://aip.baidubce.com/rest/2.0/face/scene/faceset/group/add"; + std::string _face_capture_search = + "https://aip.baidubce.com/rest/2.0/face/capture/search"; + std::string _face_idmatch_date_v4 = + "https://aip.baidubce.com/rest/2.0/face/v4/idmatch_date"; + std::string _face_verify_date_v4 = + "https://aip.baidubce.com/rest/2.0/face/v4/verify_date"; + + Face(const std::string & app_id, const std::string & ak, const std::string & sk): AipBase(app_id, ak, sk) + { + } + + std::string vector_join_base64(const std::vector & v_images) { + std::string images; + size_t count = v_images.size(); + for (size_t i = 0; i < count;i++) + { + std::string image = v_images[i]; + images += base64_encode(image.c_str(), (int) image.size()); + if (i != count) { + images += ","; + } + + } + return images; + } + + /** + * detect + * API文档: https://ai.baidu.com/ai-doc/FACE/fk3co86lr + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * max_face_num 最多处理人脸数目,默认值1 + * face_fields 包括age,beauty,expression,faceshape,gender,glasses,landmark,race,qualities信息,逗号分隔,默认只返回人脸框、概率和旋转角度 + */ + Json::Value detect(std::string const & image, const Json::Value & options) + { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_detect, data, &headers); + + return result; + } + + /** + * match + * @param images vector多图图像文件二进制内容,vector中每一项可以使用aip::get_file_content函数获取 + * options 可选参数: + * ext_fields 返回质量信息,取值固定:目前支持qualities(质量检测)。(对所有图片都会做改处理) + * image_liveness 返回的活体信息,“faceliveness,faceliveness” 表示对比对的两张图片都做活体检测;“,faceliveness” 表示对第一张图片不做活体检测、第二张图做活体检测;“faceliveness,” 表示对第一张图片做活体检测、第二张图不做活体检测;
**注:需要用于判断活体的图片,图片中的人脸像素面积需要不小于100px\*100px,人脸长宽与图片长宽比例,不小于1/3** + * types 请求对比的两张图片的类型,示例:“7,13”
**12**表示带水印证件照:一般为带水印的小图,如公安网小图
**7**表示生活照:通常为手机、相机拍摄的人像图片、或从网络获取的人像图片等
**13**表示证件照片:如拍摄的身份证、工卡、护照、学生证等证件图片,**注**:需要确保人脸部分不可太小,通常为100px\*100px + */ + Json::Value match( + const std::vector & images, + const std::map & options) + { + std::map data; + + data["images"] = vector_join_base64(images); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_match, null, data, null); + + return result; + } + + /** + * identify + * @param group_id 用户组id(由数字、字母、下划线组成),长度限制128B,多个用户组id,用逗号分隔 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * ext_fields 特殊返回信息,多个用逗号分隔,取值固定: 目前支持faceliveness(活体检测)。**注:需要用于判断活体的图片,图片中的人脸像素面积需要不小于100px\*100px,人脸长宽与图片长宽比例,不小于1/3** + + * user_top_num 返回用户top数,默认为1,最多返回5个 + */ + Json::Value identify( + std::string const & group_id, + std::string const & image, + const std::map & options) + { + std::map data; + + data["group_id"] = group_id; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_identify, null, data, null); + + return result; + } + + /** + * verify + * @param uid 用户id(由数字、字母、下划线组成),长度限制128B + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * @param group_id 用户组id(由数字、字母、下划线组成),长度限制128B,多个用户组id,用逗号分隔 + * options 可选参数: + * top_num 返回用户top数,默认为1 + * ext_fields 特殊返回信息,多个用逗号分隔,取值固定: 目前支持faceliveness(活体检测)。**注:需要用于判断活体的图片,图片中的人脸像素面积需要不小于100px\*100px,人脸长宽与图片长宽比例,不小于1/3** + + */ + Json::Value verify( + std::string const & uid, + std::string const & image, + std::string const & group_id, + const std::map & options) + { + std::map data; + + data["uid"] = uid; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["group_id"] = group_id; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_verify, null, data, null); + + return result; + } + + /** + * user_add + * @param uid 用户id(由数字、字母、下划线组成),长度限制128B + * @param user_info 用户资料,长度限制256B + * @param group_id 用户组id,标识一组用户(由数字、字母、下划线组成),长度限制128B。如果需要将一个uid注册到多个group下,group\_id需要用多个逗号分隔,每个group_id长度限制为48个英文字符。**注:group无需单独创建,注册用户时则会自动创建group。**
**产品建议**:根据您的业务需求,可以将需要注册的用户,按照业务划分,分配到不同的group下,例如按照会员手机尾号作为groupid,用于刷脸支付、会员计费消费等,这样可以尽可能控制每个group下的用户数与人脸数,提升检索的准确率 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * action_type 参数包含append、replace。**如果为“replace”,则每次注册时进行替换replace(新增或更新)操作,默认为append操作**。例如:uid在库中已经存在时,对此uid重复注册时,新注册的图片默认会**追加**到该uid下,如果手动选择`action_type:replace`,则会用新图替换库中该uid下所有图片。 + */ + Json::Value user_add( + std::string const & uid, + std::string const & user_info, + std::string const & group_id, + std::string const & image, + const std::map & options) + { + std::map data; + + data["uid"] = uid; + data["user_info"] = user_info; + data["group_id"] = group_id; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_user_add, null, data, null); + + return result; + } + + /** + * user_update + * @param uid 用户id(由数字、字母、下划线组成),长度限制128B + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * @param user_info 用户资料,长度限制256B + * @param group_id 更新指定groupid下uid对应的信息 + * options 可选参数: + * action_type 目前仅支持replace,uid不存在时,不报错,会自动变为注册操作;未选择该参数时,如果uid不存在会提示错误 + */ + Json::Value user_update( + std::string const & uid, + std::string const & image, + std::string const & user_info, + std::string const & group_id, + const std::map & options) + { + std::map data; + + data["uid"] = uid; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["user_info"] = user_info; + data["group_id"] = group_id; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_user_update, null, data, null); + + return result; + } + + /** + * user_delete + * @param uid 用户id(由数字、字母、下划线组成),长度限制128B + * @param group_id 删除指定groupid下uid对应的信息 + * options 可选参数: + */ + Json::Value user_delete( + std::string const & uid, + std::string const & group_id, + const std::map & options) + { + std::map data; + + data["uid"] = uid; + data["group_id"] = group_id; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_user_delete, null, data, null); + + return result; + } + + /** + * user_get + * @param uid 用户id(由数字、字母、下划线组成),长度限制128B + * options 可选参数: + * group_id 选择指定group_id则只查找group列表下的uid内容,如果不指定则查找所有group下对应uid的信息 + */ + Json::Value user_get( + std::string const & uid, + const std::map & options) + { + std::map data; + + data["uid"] = uid; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_user_get, null, data, null); + + return result; + } + + /** + * group_getlist + * options 可选参数: + * start 默认值0,起始序号 + * end 返回数量,默认值100,最大值1000 + */ + Json::Value group_getlist( + const std::map & options) + { + std::map data; + + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_group_getlist, null, data, null); + + return result; + } + + /** + * group_getusers + * @param group_id 用户组id(由数字、字母、下划线组成),长度限制128B + * options 可选参数: + * start 默认值0,起始序号 + * end 返回数量,默认值100,最大值1000 + */ + Json::Value group_getusers( + std::string const & group_id, + const std::map & options) + { + std::map data; + + data["group_id"] = group_id; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_group_getusers, null, data, null); + + return result; + } + + /** + * group_adduser + * @param group_id 用户组id(由数字、字母、下划线组成),长度限制128B,多个用户组id,用逗号分隔 + * @param uid 用户id(由数字、字母、下划线组成),长度限制128B + * @param src_group_id 从指定group里复制信息 + * options 可选参数: + */ + Json::Value group_adduser( + std::string const & group_id, + std::string const & uid, + std::string const & src_group_id, + const std::map & options) + { + std::map data; + + data["group_id"] = group_id; + data["uid"] = uid; + data["src_group_id"] = src_group_id; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_group_adduser, null, data, null); + + return result; + } + + /** + * group_deleteuser + * @param group_id 用户组id(由数字、字母、下划线组成),长度限制128B,多个用户组id,用逗号分隔 + * @param uid 用户id(由数字、字母、下划线组成),长度限制128B + * options 可选参数: + */ + Json::Value group_deleteuser( + std::string const & group_id, + std::string const & uid, + const std::map & options) + { + std::map data; + + data["group_id"] = group_id; + data["uid"] = uid; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_group_deleteuser, null, data, null); + + return result; + } + + /** + * 人脸 - 人脸实名认证V4 + * 基于姓名和身份证号,调取公安权威数据源人脸图,将当前获取的人脸图片,与此公安数据源人脸图进行对比,得出比对分数,并基于此进行业务判断是否为同一人 + * @param idCardNumber 身份证件号 + * @param name 姓名(需要是 utf8 编码) + * @param image 图片信息(数据大小应小于10M 分辨率应小于1920*1080),5.2版本SDK请求时已包含在加密数据data中,无需额外传入 + * options 可选参数: + * quality_control 质量控制参数 + */ + Json::Value faceMingJingVerify( + const std::string& idCardNumber, + const std::string& name, + std::string* image, + std::map options) + { + std::string access_token = this->getAccessToken(); + + Json::Value data; + data["id_card_number"] = idCardNumber; + data["name"] = name; + if (image != nullptr) { + data["image"] = *image; + } + + std::map< std::string,std::string >::iterator it ; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + std::string mid = "?access_token="; + std::string url = _face_verify_v4 + mid + access_token; + Json::Value result = + this->request_com(url, data); + + return result; + } + + /** + * 人脸 - 人脸对比V4 + * 用于比对多张图片中的人脸相似度并返回两两比对的得分,可用于判断两张脸是否是同一人的可能性大小 + * @param image 图片信息(数据大小应小于10M 分辨率应小于1920*1080),5.2版本SDK请求时已包含在加密数据data中,无需额外传入 + * @param imageType 图片类型 + * @param registerImage 图片信息(总数据大小应小于10M),图片上传方式根据image_type来判断。本图片特指客户服务器上传图片,非加密图片Base64值 + * @param registerImageType 图片类型 + * options 可选参数 + */ + Json::Value faceMingJingMatch( + std::string * image, + std::string * imageType, + const std::string& registerImage, + const std::string& registerImageType, + std::map options) + { + std::string access_token = this->getAccessToken(); + + Json::Value data; + if (image != nullptr) { + data["image"] = *image; + } + if (imageType != nullptr) { + data["image_type"] = *imageType; + } + data["register_image"] = registerImage; + data["register_image_type"] = registerImageType; + + std::map< std::string,std::string >::iterator it ; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + std::string mid = "?access_token="; + std::string url = _face_match_v4 + mid + access_token; + Json::Value result = + this->request_com(url, data); + + return result; + } + + /** + * 人脸 - 在线图片活体V4 + * 基于单张图片,判断图片中的人脸是否为二次翻拍 + * @param sdkVersion sdk版本 + * options 可选参数 + */ + Json::Value onlinePictureLiveV4( + const std::string& sdkVersion, + std::vector& imageList, + std::map options) + { + std::string access_token = this->getAccessToken(); + + Json::Value data; + data["sdk_version"] = sdkVersion; + Json::Value imageListJson; + for (std::string image : imageList) { + imageListJson.append(image); + } + data["image_list"] = imageListJson; + + std::map< std::string,std::string >::iterator it ; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + std::string mid = "?access_token="; + std::string url = _online_picture_live_v4 + mid + access_token; + Json::Value result = + this->request_com(url, data); + + return result; + } + + /** + * 随机校验码 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Ikrycq2k2 + */ + Json::Value faceliveness_sessioncode_v1(const Json::Value & options) + { + Json::Value data; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_faceliveness_sessioncode_v1, data, &headers); + + return result; + } + + /** + * H5视频活体检测 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Ikrycq2k2 + */ + Json::Value faceliveness_verify_v1( + std::string const &video_base64, + const std::map &options) + { + std::map data; + data["video_base64"] = base64_encode(video_base64.c_str(), (int) video_base64.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_faceliveness_verify_v1, null, data, null); + + return result; + } + + /** + * 人脸检测 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/yk37c1u4t + */ + Json::Value face_detect_v3( + std::string const &image, + std::string const &image_type, + const Json::Value & options) + { + Json::Value data; + data["image"] = image; + data["image_type"] = image_type; + merge_json(data, options); + + Json::Value result = this->request_com(_face_detect_v3, data); + return result; + } + + /** + * 人脸对比 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Lk37c1tpf + */ + Json::Value face_match_v3(Json::Value & image_array) + { + Json::Value result = this->request_com(_face_match_v3, image_array); + return result; + } + + /** + * 人脸搜索 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc + */ + Json::Value face_search_v3( + std::string const &image, + std::string const &image_type, + std::string const &group_id_list, + const Json::Value & options) + { + Json::Value data; + data["image"] = image; + data["image_type"] = image_type; + data["group_id_list"] = group_id_list; + merge_json(data, options); + + Json::Value result = this->request_com(_face_search_v3, data); + return result; + } + + /** + * 人脸注册 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E4%BA%BA%E8%84%B8%E6%B3%A8%E5%86%8C + */ + Json::Value face_faceset_user_add_v3( + std::string const &image, + std::string const &image_type, + std::string const &group_id, + std::string const &user_id, + const Json::Value & options) + { + Json::Value data; + data["image"] = image; + data["image_type"] = image_type; + data["group_id"] = group_id; + data["user_id"] = user_id; + merge_json(data, options); + + Json::Value result = this->request_com(_face_faceset_user_add_v3, data); + return result; + } + + /** + * 人脸更新 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E4%BA%BA%E8%84%B8%E6%9B%B4%E6%96%B0 + */ + Json::Value face_faceset_user_update_v3( + std::string const &image, + std::string const &image_type, + std::string const &group_id, + std::string const &user_id, + const Json::Value & options) + { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["image_type"] = image_type; + data["group_id"] = group_id; + data["user_id"] = user_id; + merge_json(data, options); + + Json::Value result = this->request_com(_face_faceset_user_update_v3, data); + return result; + } + + /** + * 删除用户 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E5%88%A0%E9%99%A4%E7%94%A8%E6%88%B7 + */ + Json::Value face_faceset_user_delete_v3( + std::string const &group_id, + std::string const &user_id) + { + Json::Value data; + data["group_id"] = group_id; + data["user_id"] = user_id; + + Json::Value result = this->request_com(_face_faceset_user_delete_v3, data); + return result; + } + + /** + * 用户信息查询 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E7%94%A8%E6%88%B7%E4%BF%A1%E6%81%AF%E6%9F%A5%E8%AF%A2 + */ + Json::Value face_faceset_user_get_v3( + std::string const &group_id, + std::string const &user_id) + { + Json::Value data; + data["group_id"] = group_id; + data["user_id"] = user_id; + + Json::Value result = this->request_com(_face_faceset_user_get_v3, data); + return result; + } + + /** + * 获取组列表 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E7%BB%84%E5%88%97%E8%A1%A8%E6%9F%A5%E8%AF%A2 + */ + Json::Value face_faceset_group_getlist_v3(uint32_t start, uint32_t length) + { + Json::Value data; + data["start"] = start; + data["length"] = length; + + Json::Value result = this->request_com(_face_faceset_group_getlist_v3, data); + return result; + } + + /** + * 获取用户列表 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E8%8E%B7%E5%8F%96%E7%94%A8%E6%88%B7%E5%88%97%E8%A1%A8 + */ + Json::Value face_faceset_group_getusers_v3( + std::string const &group_id, const Json::Value & options) + { + Json::Value data; + data["group_id"] = group_id; + merge_json(data, options); + + Json::Value result = this->request_com(_face_faceset_group_getusers_v3, data); + return result; + } + + /** + * 复制用户 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E5%A4%8D%E5%88%B6%E7%94%A8%E6%88%B7 + */ + Json::Value face_faceset_user_copy_v3( + std::string const &user_id, + std::string const &src_group_id, + std::string const &dst_group_id) + { + Json::Value data; + data["user_id"] = user_id; + data["src_group_id"] = src_group_id; + data["dst_group_id"] = dst_group_id; + + Json::Value result = this->request_com(_face_faceset_user_copy_v3, data); + return result; + } + + /** + * 获取用户人脸列表 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E8%8E%B7%E5%8F%96%E7%94%A8%E6%88%B7%E4%BA%BA%E8%84%B8%E5%88%97%E8%A1%A8 + */ + Json::Value face_fasetset_face_getlist_v3( + std::string const &user_id, + std::string const &group_id) + { + Json::Value data; + data["user_id"] = user_id; + data["group_id"] = group_id; + + Json::Value result = this->request_com(_face_fasetset_face_getlist_v3, data); + return result; + } + + /** + * 创建用户组 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E5%88%9B%E5%BB%BA%E7%94%A8%E6%88%B7%E7%BB%84 + */ + Json::Value face_faceset_group_add_v3( + std::string const &group_id) + { + Json::Value data; + data["group_id"] = group_id; + + Json::Value result = this->request_com(_face_faceset_group_add_v3, data); + return result; + } + + /** + * 删除用户组 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E5%88%A0%E9%99%A4%E7%94%A8%E6%88%B7%E7%BB%84 + */ + Json::Value face_faceset_group_delete_v3( + std::string const &group_id) + { + Json::Value data; + data["group_id"] = group_id; + + Json::Value result = this->request_com(_face_faceset_group_delete_v3, data); + return result; + } + + /** + * 删除人脸 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E4%BA%BA%E8%84%B8%E5%88%A0%E9%99%A4 + */ + Json::Value face_faceset_face_delete_v3( + long long log_id, + std::string const &user_id, + std::string const &group_id, + std::string const &face_token) + { + Json::Value data; + data["log_id"] = log_id; + data["user_id"] = user_id; + data["group_id"] = group_id; + data["face_token"] = face_token; + + Json::Value result = this->request_com(_face_faceset_face_delete_v3, data); + return result; + } + + /** + * 在线活体检测V3 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Zk37c1urr + */ + Json::Value face_faceverify_v3(Json::Value & image_array) + { + Json::Value result = this->request_com(_face_faceverify_v3, image_array); +// +// Json::Value data; +// data["image"] = base64_encode(image.c_str(), (int) image.size()); +// data["image_type"] = image_type; +// merge_json(data, options); +// +// Json::Value result = this->request_com(_face_faceverify_v3, data); + return result; + } + + /** + * 身份证与名字比对 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Tkqahnjtk + */ + Json::Value face_person_idmatch_v3( + std::string const &id_card_number, + std::string const &name) + { + Json::Value data; + data["id_card_number"] = id_card_number; + data["name"] = name; + + Json::Value result = this->request_com(_face_person_idmatch_v3, data); + return result; + } + + /** + * 人脸搜索-M:N识别 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Gk37c1uzc#%E4%BA%BA%E8%84%B8%E6%90%9C%E7%B4%A2-mn-%E8%AF%86%E5%88%AB + */ + Json::Value face_multi_search_v3( + std::string const &image, + std::string const &image_type, + std::string const &group_id_list, + const Json::Value & options) + { + Json::Value data; + data["image"] = image; + data["image_type"] = image_type; + data["group_id_list"] = group_id_list; + merge_json(data, options); + + Json::Value result = this->request_com(_face_multi_search_v3, data); + return result; + } + + /** + * 人脸融合 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/5k37c1ti0 + */ + Json::Value face_merge_v1( + const Json::Value & image_template, + const Json::Value & image_target, + const Json::Value & options) + { + Json::Value data; + data["image_template"] = image_template; + data["image_target"] = image_target; + merge_json(data, options); + + Json::Value result = this->request_com(_face_merge_v1, data); + return result; + } + + /** + * 人脸属性编辑 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/vk6rm5lj5 + */ + Json::Value face_skin_smooth_v1( + std::string const &image, + std::string const &image_type, + std::string const &action_type, + const Json::Value & options) + { + Json::Value data; + data["image"] = image; + data["image_type"] = image_type; + data["action_type"] = action_type; + merge_json(data, options); + + Json::Value result = this->request_com(_face_skin_smooth_v1, data); + return result; + } + + /** + * 人脸关键点检测 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/sk8a5xewt + */ + Json::Value face_landmark_v1( + std::string const &image, + std::string const &image_type, + const Json::Value & options) + { + Json::Value data; + data["image"] = image; + data["image_type"] = image_type; + merge_json(data, options); + + Json::Value result = this->request_com(_face_landmark_v1, data); + return result; + } + + /** + * 场景化(人脸注册) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Aknhmx6hi#%E4%BA%BA%E8%84%B8%E5%BA%93%E7%AE%A1%E7%90%86%EF%BC%88%E5%9C%BA%E6%99%AF%E5%8C%96%EF%BC%89-%E4%BA%BA%E8%84%B8%E6%B3%A8%E5%86%8C + */ + Json::Value face_scene_faceset_user_add( + std::string const &image, + std::string const &image_type, + std::string const &group_id, + std::string const &user_id, + std::string const &scene_type, + const Json::Value & options) + { + Json::Value data; + data["image"] = image; + data["image_type"] = image_type; + data["group_id"] = group_id; + data["user_id"] = user_id; + data["scene_type"] = scene_type; + merge_json(data, options); + + Json::Value result = this->request_com(_face_scene_faceset_user_add, data); + return result; + } + + /** + * 场景化(人脸更新) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Aknhmx6hi#%E4%BA%BA%E8%84%B8%E5%BA%93%E7%AE%A1%E7%90%86%EF%BC%88%E5%9C%BA%E6%99%AF%E5%8C%96%EF%BC%89-%E4%BA%BA%E8%84%B8%E6%9B%B4%E6%96%B0 + */ + Json::Value face_scene_faceset_user_update( + std::string const &image, + std::string const &image_type, + std::string const &group_id, + std::string const &user_id, + std::string const &scene_type, + const Json::Value & options) + { + Json::Value data; + data["image"] = image; + data["image_type"] = image_type; + data["group_id"] = group_id; + data["user_id"] = user_id; + data["scene_type"] = scene_type; + merge_json(data, options); + + Json::Value result = this->request_com(_face_scene_faceset_user_update, data); + return result; + } + + /** + * 场景化(创建用户组) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Aknhmx6hi#%E4%BA%BA%E8%84%B8%E5%BA%93%E7%AE%A1%E7%90%86%EF%BC%88%E5%9C%BA%E6%99%AF%E5%8C%96%EF%BC%89-%E5%88%9B%E5%BB%BA%E7%94%A8%E6%88%B7%E7%BB%84 + */ + Json::Value face_scene_faceset_group_add( + std::string const &group_id, + std::string const &scene_type) + { + Json::Value data; + data["group_id"] = group_id; + data["scene_type"] = scene_type; + + Json::Value result = this->request_com(_face_scene_faceset_group_add, data); + return result; + } + + /** + * 人脸搜索(视频监控) + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/Aknhmx6hi + */ + Json::Value face_capture_search( + std::string const &image, + std::string const &image_type, + std::string const &group_id_list, + const Json::Value & options) + { + Json::Value data; + data["image"] = image; + data["image_type"] = image_type; + data["group_id_list"] = group_id_list; + merge_json(data, options); + + Json::Value result = this->request_com(_face_capture_search, data); + return result; + } + + /** + * 身份证信息及有效期核验接口 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/elav5puig + */ + Json::Value face_idmatch_date_v4( + std::string const &name, + std::string const &id_card_number, + std::string const &start_date, + std::string const &end_date) + { + Json::Value data; + data["name"] = name; + data["id_card_number"] = id_card_number; + data["start_date"] = start_date; + data["end_date"] = end_date; + + Json::Value result = this->request_com(_face_idmatch_date_v4, data); + return result; + } + + /** + * 人脸实名信息及有效期核验 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/FACE/qlav5rwms + */ + Json::Value face_verify_date_v4( + std::string const &name, + std::string const &id_card_number, + std::string const &start_date, + std::string const &end_date, + std::string const &image, + std::string const &image_type, + const Json::Value & options) + { + Json::Value data; + data["name"] = name; + data["id_card_number"] = id_card_number; + data["start_date"] = start_date; + data["end_date"] = end_date; + data["image"] = image; + data["image_type"] = image_type; + merge_json(data, options); + + Json::Value result = this->request_com(_face_verify_date_v4, data); + return result; + } + }; +} +#endif diff --git a/third/include/aip-cpp-sdk/image_censor.h b/third/include/aip-cpp-sdk/image_censor.h new file mode 100644 index 0000000..0a1c09c --- /dev/null +++ b/third/include/aip-cpp-sdk/image_censor.h @@ -0,0 +1,111 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_IMAGECENSOR_H__ +#define __AIP_IMAGECENSOR_H__ + +#include "base/base.h" + +namespace aip { + + class Imagecensor: public AipBase + { + public: + + + std::string _anti_porn = + "https://aip.baidubce.com/rest/2.0/antiporn/v1/detect"; + + std::string _anti_porn_gif = + "https://aip.baidubce.com/rest/2.0/antiporn/v1/detect_gif"; + + std::string _anti_terror = + "https://aip.baidubce.com/rest/2.0/antiterror/v1/detect"; + + + Imagecensor(const std::string & app_id, const std::string & ak, const std::string & sk): AipBase(app_id, ak, sk) + { + } + + /** + * anti_porn + * 该请求用于鉴定图片的色情度。即对于输入的一张图片(可正常解码,且长宽比适宜),输出图片的色情度。目前支持三个维度:色情、性感、正常。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value anti_porn( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_anti_porn, null, data, null); + + return result; + } + + /** + * anti_porn_gif + * 该请求用于鉴定GIF图片的色情度,对于非gif接口,请使用色情识别接口。接口会对图片中每一帧进行识别,并返回所有检测结果中色情值最大的为结果。目前支持三个维度:色情、性感、正常。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value anti_porn_gif( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_anti_porn_gif, null, data, null); + + return result; + } + + /** + * anti_terror + * 该请求用于鉴定图片是否涉暴涉恐。即对于输入的一张图片(可正常解码,且长宽比适宜),输出图片的涉暴涉恐程度。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value anti_terror( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_anti_terror, null, data, null); + + return result; + } + + + }; +} +#endif \ No newline at end of file diff --git a/third/include/aip-cpp-sdk/image_classify.h b/third/include/aip-cpp-sdk/image_classify.h new file mode 100644 index 0000000..4867e84 --- /dev/null +++ b/third/include/aip-cpp-sdk/image_classify.h @@ -0,0 +1,1003 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_IMAGECLASSIFY_H__ +#define __AIP_IMAGECLASSIFY_H__ + +#include "base/base.h" +#include + +namespace aip { + + class Imageclassify : public AipBase { + public: + + std::string _traffic_flow = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/traffic_flow"; + + std::string _vehicle_damage = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/vehicle_damage"; + + std::string _vehicle_seg = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/vehicle_seg"; + + std::string _vehicle_detect = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/vehicle_detect"; + + std::string _vehicle_detect_high = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/vehicle_detect_high"; + + std::string _vehicle_attr = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/vehicle_attr"; + + std::string _redwine = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/redwine"; + + std::string _currency = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/currency"; + + std::string _dishadd = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/dish/add"; + + std::string _combination = + "https://aip.baidubce.com/api/v1/solution/direct/imagerecognition/combination"; + + std::string _dishDelete = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/dish/delete"; + + std::string _ingredient = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/classify/ingredient"; + + std::string _dishSearch = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/dish/search"; + + std::string _mult_object_detect = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/multi_object_detect"; + + std::string _dish_detect = + "https://aip.baidubce.com/rest/2.0/image-classify/v2/dish"; + + std::string _car_detect = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/car"; + + std::string _logo_search = + "https://aip.baidubce.com/rest/2.0/image-classify/v2/logo"; + + std::string _logo_add = + "https://aip.baidubce.com/rest/2.0/realtime_search/v1/logo/add"; + + std::string _logo_delete = + "https://aip.baidubce.com/rest/2.0/realtime_search/v1/logo/delete"; + + std::string _animal_detect = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/animal"; + + std::string _plant_detect = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/plant"; + + std::string _object_detect = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/object_detect"; + + std::string _advanced_general = + "https://aip.baidubce.com/rest/2.0/image-classify/v2/advanced_general"; + + std::string _landmark_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/landmark"; + + std::string _redwine_add_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/redwine/add"; + std::string _redwine_search_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/redwine/search"; + std::string _redwine_delete_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/redwine/delete"; + std::string _redwine_update_v1 = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/redwine/update"; + std::string _vehicle_attr_classify_v2 = + "https://aip.baidubce.com/rest/2.0/image-classify/v2/vehicle_attr"; + + Imageclassify(const std::string &app_id, const std::string &ak, const std::string &sk) : AipBase(app_id, ak, + sk) { + } + + /** + * dish_detect + * 该请求用于菜品识别。即对于输入的一张图片(可正常解码,且长宽比适宜),输出图片的菜品名称、卡路里信息、置信度。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * top_num 返回预测得分top结果数,默认为5 + */ + Json::Value dish_detect( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_dish_detect, null, data, null); + + return result; + } + + /** + * car_detect + * 该请求用于检测一张车辆图片的具体车型。即对于输入的一张图片(可正常解码,且长宽比适宜),输出图片的车辆品牌及型号。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * top_num 返回预测得分top结果数,默认为5 + */ + Json::Value car_detect( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_car_detect, null, data, null); + + return result; + } + + /** + * logo_search + * 该请求用于检测和识别图片中的品牌LOGO信息。即对于输入的一张图片(可正常解码,且长宽比适宜),输出图片中LOGO的名称、位置和置信度。 当效果欠佳时,可以建立子库(请加入QQ群:649285136 联系工作人员申请建库)并自定义logo入库,提高识别效果。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * custom_lib 是否只使用自定义logo库的结果,默认false:返回自定义库+默认库的识别结果 + */ + Json::Value logo_search( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_logo_search, null, data, null); + + return result; + } + + /** + * logo_add + * 该接口尚在邀测阶段,使用该接口之前需要线下联系工作人员完成建库方可使用,请加入QQ群:649285136 联系相关人员。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * @param brief brief,检索时带回。此处要传对应的name与code字段,name长度小于100B,code长度小于150B + * options 可选参数: + */ + Json::Value logo_add( + std::string const &image, + std::string const &brief, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["brief"] = brief; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_logo_add, null, data, null); + + return result; + } + + /** + * logo_delete_by_image + * 该接口尚在邀测阶段,使用该接口之前需要线下联系工作人员完成建库方可使用,请加入QQ群:649285136 联系相关人员。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value logo_delete_by_image( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_logo_delete, null, data, null); + + return result; + } + + /** + * logo_delete_by_sign + * 该接口尚在邀测阶段,使用该接口之前需要线下联系工作人员完成建库方可使用,请加入QQ群:649285136 联系相关人员。 + * @param cont_sign 图片签名(和image二选一,image优先级更高) + * options 可选参数: + */ + Json::Value logo_delete_by_sign( + std::string const &cont_sign, + const std::map &options) { + std::map data; + + data["cont_sign"] = cont_sign; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_logo_delete, null, data, null); + + return result; + } + + /** + * animal_detect + * 该请求用于识别一张图片。即对于输入的一张图片(可正常解码,且长宽比适宜),输出动物识别结果 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * top_num 返回预测得分top结果数,默认为6 + */ + Json::Value animal_detect( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_animal_detect, null, data, null); + + return result; + } + + /** + * plant_detect + * 该请求用于识别一张图片。即对于输入的一张图片(可正常解码,且长宽比适宜),输出植物识别结果。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value plant_detect( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_plant_detect, null, data, null); + + return result; + } + + /** + * object_detect + * 用户向服务请求检测图像中的主体位置。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * with_face 如果检测主体是人,主体区域是否带上人脸部分,0-不带人脸区域,其他-带人脸区域,裁剪类需求推荐带人脸,检索/识别类需求推荐不带人脸。默认取1,带人脸。 + */ + Json::Value object_detect( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_object_detect, null, data, null); + + return result; + } + + + /** + * 图像多主体检测 + * 检测出图片中多个主体,并给出位置、标签和置信得分。 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value multobjectdetect( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_mult_object_detect, null, data, null); + + return result; + } + + /** + * 自定义菜单识别检索 + * 在已自建菜品库并入库的情况下,该接口实现单菜品/多菜品的识别。 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value dishsearch( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_dishSearch, null, data, null); + + return result; + } + + /** + * 果蔬识别 + * 该请求用于识别果蔬类食材,即对于输入的一张图片(可正常解码,且长宽比适宜),输出图片中的果蔬食材结果。 + * @param image 二进制图像数据 + * options 可选参数: + * topNum 返回预测得分top结果数,如果为空或小于等于0默认为5;如果大于20默认20 + */ + Json::Value ingredient( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_ingredient, null, data, null); + + return result; + } + + /** + * 自定义菜单识别删除 + * 在已自建菜品库并入库的情况下,该接口实现单菜品/多菜品的识别。 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value dishdeletebyimage( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_dishDelete, null, data, null); + + return result; + } + + /** + * 组合接口 + * 同时调用多个模型服务。支持图像识别的多个接口 + * @param image 二进制图像数据 + * @param scenes 本次调用的模型服务,数组表示 + * options 可选参数: + * sceneConf 对特定服务,支持的个性化参数,若不填则使用默认设置 + */ + Json::Value combination(std::string const &image, + Json::Value const &scenes, + Json::Value const &options) { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["scenes"] = scenes; + merge_json(data, options); + + Json::Value result = this->request_com(_combination, data); + return result; + } + + /** + * 自定义菜单识别删除 + * 入库菜品图片的删除操作 + * @param contSign 图片签名 + * options 可选参数: + */ + Json::Value dishdeletebycontsign( + std::string const &contSign, + const std::map &options) { + std::map data; + + data["contSign"] = contSign; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_dishDelete, null, data, null); + + return result; + } + + /** + * 自定义菜单识别 + * 该接口实现单张菜品图片入库,入库时需要同步提交图片及可关联至本地菜品图库的摘要信息(具体变量为brief,brief可传入图片在本地标记id、图片url、图片名称等) + * @param image 二进制图像数据 + * @param brief 菜品名称摘要信息,检索时带回,不超过256B + * options 可选参数: + */ + Json::Value dishadd( + std::string const &image, + std::string const &brief, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["brief"] = brief; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_dishadd, null, data, null); + + return result; + } + + /**红酒识别 + * 在已自建菜品库并入库的情况下,该接口实现单菜品/多菜品的识别。 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value redwine( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_redwine, null, data, null); + + return result; + } + + /**红酒识别 + * 在已自建菜品库并入库的情况下,该接口实现单菜品/多菜品的识别。 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value redwineUrl( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_redwine, null, data, null); + + return result; + } + + /**货币识别 + * 在已自建菜品库并入库的情况下,该接口实现单菜品/多菜品的识别。 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value currency( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_currency, null, data, null); + + return result; + } + + /**货币识别 + * 在已自建菜品库并入库的情况下,该接口实现单菜品/多菜品的识别。 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value currencyUrl( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_currency, null, data, null); + + return result; + } + + /** + * 组合接口 + * 同时调用多个模型服务。支持图像识别的多个接口 + * @param image 二进制图像数据 + * @param scenes 本次调用的模型服务,数组表示 + * options 可选参数: + * sceneConf 对特定服务,支持的个性化参数,若不填则使用默认设置 + */ + Json::Value combinationUrl(std::string const &imgUrl, + Json::Value const &scenes, + Json::Value const &options) { + Json::Value data; + data["imgUrl"] = imgUrl; + data["scenes"] = scenes; + merge_json(data, options); + + Json::Value result = this->request_com(_combination, data); + return result; + } + + /** + * 车辆属性识别 + * 传入单帧图像,检测图片中所有车辆,返回每辆车的类型和坐标位置,可识别小汽车、卡车、巴士、摩托车、三轮车、自行车6大类车辆, + * @param image 二进制图像数据 + * options 可选参数: + * type 是否选定某些属性输出对应的信息,可从12种输出属性中任选若干,用英文逗号分隔(例如vehicle_type,roof_rack,skylight)。默认输出全部属性 + */ + Json::Value vehicleAttr( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_attr, null, data, null); + return result; + } + + /** + * 车辆属性识别 + * 传入单帧图像,检测图片中所有车辆,返回每辆车的类型和坐标位置,可识别小汽车、卡车、巴士、摩托车、三轮车、自行车6大类车辆, + * @param url 图片完整URL + * options 可选参数: + * type 是否选定某些属性输出对应的信息,可从12种输出属性中任选若干,用英文逗号分隔(例如vehicle_type,roof_rack,skylight)。默认输出全部属性 + */ + Json::Value vehicleAttrUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_attr, null, data, null); + return result; + } + + /** + * 车辆检测-高空版 + * 面向高空拍摄视角(30米以上),传入单帧图像,检测图片中所有车辆,返回每辆车的坐标位置(不区分车辆类型),并进行车辆计数,支持指定矩形区域的车辆检测与数量统计。 + * @param image 二进制图像数据 + * options 可选参数: + * area 只统计该矩形区域内的车辆数,缺省时为全图统计。逗号分隔,如‘x1,y1,x2,y2,x3,y3...xn,yn',按顺序依次给出每个顶点的x、y坐标(默认尾点和首点相连),形成闭合矩形区域。 + */ + Json::Value vehicleDetectHigh( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_detect_high, null, data, null); + return result; + } + + /** + * 车辆检测-高空版 + * 面向高空拍摄视角(30米以上),传入单帧图像,检测图片中所有车辆,返回每辆车的坐标位置(不区分车辆类型),并进行车辆计数,支持指定矩形区域的车辆检测与数量统计。 + * @param url 图片完整URL + * options 可选参数: + * area 只统计该矩形区域内的车辆数,缺省时为全图统计。逗号分隔,如‘x1,y1,x2,y2,x3,y3...xn,yn',按顺序依次给出每个顶点的x、y坐标(默认尾点和首点相连),形成闭合矩形区域。 + */ + Json::Value vehicleDetectHighUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_detect_high, null, data, null); + return result; + } + + /** + * 车型识别 + * 识别图片中车辆的具体车型,可识别常见的3000+款车型(小汽车为主),输出车辆的品牌型号、颜色、年份、位置信息;支持返回对应识别结果的百度百科词条信息,包含词条名称、百科页面链接、百科图片链接、百科内容简介。 + * @param image 二进制图像数据 + * options 可选参数: + * top_num 返回结果top n,默认5。e * baike_num 返回百科信息的结果数,默认不返回 + */ + Json::Value carDetect( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_car_detect, null, data, null); + return result; + } + + /** + * 车型识别 + * 识别图片中车辆的具体车型,可识别常见的3000+款车型(小汽车为主),输出车辆的品牌型号、颜色、年份、位置信息;支持返回对应识别结果的百度百科词条信息,包含词条名称、百科页面链接、百科图片链接、百科内容简介。 + * @param url 图片完整URL + * options 可选参数: + * top_num 返回结果top n,默认5。e * baike_num 返回百科信息的结果数,默认不返回 + */ + Json::Value carDetectUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_car_detect, null, data, null); + return result; + } + + /** + * 车辆检测 + * 入单帧图像,检测图片中所有机动车辆,返回每辆车的类型和坐标位置,可识别小汽车、卡车、巴士、摩托车、三轮车5类车辆,并对每类车辆分别计数,同时可定位小汽车、卡车、巴士的车牌位置,支持指定矩形区域的车辆检测与数量统计 + * @param image 二进制图像数据 + * options 可选参数: + * area 只统计该矩形区域内的车辆数,缺省时为全图统计。逗号分隔,如‘x1,y1,x2,y2,x3,y3...xn,yn',按顺序依次给出每个顶点的x、y坐标(默认尾点和首点相连),形成闭合矩形区域。 + */ + Json::Value vehicleDetect( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_detect, null, data, null); + return result; + } + + /** + * 车辆检测 + * 入单帧图像,检测图片中所有机动车辆,返回每辆车的类型和坐标位置,可识别小汽车、卡车、巴士、摩托车、三轮车5类车辆,并对每类车辆分别计数,同时可定位小汽车、卡车、巴士的车牌位置,支持指定矩形区域的车辆检测与数量统计 + * @param url 图片完整URL + * options 可选参数: + * area 只统计该矩形区域内的车辆数,缺省时为全图统计。逗号分隔,如‘x1,y1,x2,y2,x3,y3...xn,yn',按顺序依次给出每个顶点的x、y坐标(默认尾点和首点相连),形成闭合矩形区域。 + */ + Json::Value vehicleDetectUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_detect, null, data, null); + return result; + } + + /** + * 车辆分割 + * 传入单帧图像,检测图像中的车辆,以小汽车为主,识别车辆的轮廓范围,与背景进行分离,返回分割后的二值图、灰度图,支持多个车辆、车门打开、后备箱打开、机盖打开、正面、侧面、背面等各种拍摄场景。 + * @param image 二进制图像数据 + * options 可选参数: + * type 可以通过设置type参数,自主设置返回哪些结果图,避免造成带宽的浪费。1)可选值说明:labelmap - 二值图像,需二次处理方能查看分割效果scoremap - 车辆前景灰度图2)type 参数值可以是可选值的组合,用逗号分隔;如果无此参数默认输出全部3类结果图 + */ + Json::Value vehicleSeg( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_seg, null, data, null); + return result; + } + + /** + * 车辆外观损伤识别 + * 针对常见的小汽车车型,识别车辆外观受损部件及损伤类型,支持32种车辆部件、5大类外观损伤。同时可输出损伤的数值化结果(长宽、面积、部件占比),支持单图多种损伤的识别。 + * @param image 二进制图像数据 + * options 可选参数: + + */ + Json::Value vehicleDamage( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_damage, null, data, null); + return result; + } + + /** + * 车流统计 + * 根据传入的连续视频图片序列,进行车辆检测和追踪,返回每个车辆的坐标位置、车辆类型(包括小汽车、卡车、巴士、摩托车、三轮车5类)。 + * 在原图中指定区域,根据车辆轨迹判断驶入/驶出区域的行为,统计各类车辆的区域进出车流量,可返回含统计值和跟踪框的渲染图。 + * @param image 二进制图像数据 * @param case_id 任务ID(通过case_id区分不同视频流,自拟,不同序列间不可重复) + * @param case_init 每个case的初始化信号,为true时对该case下的跟踪算法进行初始化,为false时重载该case的跟踪状态。 + * 当为false且读取不到相应case的信息时,直接重新初始化 + * @param area 只统计进出该区域的车辆。逗号分隔,如‘x1,y1,x2,y2,x3,y3...xn,yn',按顺序依次给出每个顶点的x、y坐标 + * (默认尾点和首点相连),形成闭合多边形区域。 + * options 可选参数: + * show 是否返回结果图(含统计值和跟踪框)。选true时返回渲染后的图片(base64),其它无效值或为空则默认false。 + */ + Json::Value trafficFlow( + std::string image, + int case_id, + std::string case_init, + std::string area, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["case_id"] = case_id; + data["case_init"] = case_init; + data["area"] = area; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_traffic_flow, null, data, null); + return result; + } + + /** + * 车流统计 + * 根据传入的连续视频图片序列,进行车辆检测和追踪,返回每个车辆的坐标位置、车辆类型(包括小汽车、卡车、巴士、摩托车、三轮车5类)。在原图中指定区域,根据车辆轨迹判断驶入/驶出区域的行为,统计各类车辆的区域进出车流量,可返回含统计值和跟踪框的渲染图。 + * @param url 图片完整URL * @param case_id 任务ID(通过case_id区分不同视频流,自拟,不同序列间不可重复) * @param case_init 每个case的初始化信号,为true时对该case下的跟踪算法进行初始化,为false时重载该case的跟踪状态。当为false且读取不到相应case的信息时,直接重新初始化 * @param area 只统计进出该区域的车辆。逗号分隔,如‘x1,y1,x2,y2,x3,y3...xn,yn',按顺序依次给出每个顶点的x、y坐标(默认尾点和首点相连),形成闭合多边形区域。 + * options 可选参数: + * show 是否返回结果图(含统计值和跟踪框)。选true时返回渲染后的图片(base64),其它无效值或为空则默认false。 + */ + Json::Value trafficFlowUrl( + std::string url, + int case_id, + std::string case_init, + std::string area, + std::map options) { + std::map data; + data["url"] = url; + data["case_id"] = case_id; + data["case_init"] = case_init; + data["area"] = area; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_traffic_flow, null, data, null); + return result; + } + + + /** + * 通用物体识别 + * 该请求用于通用物体及场景识别,即对于输入的一张图片(可正常解码,且长宽比适宜),输出图片中的多个物体及场景标签。 + * @param image 二进制图像数据 + * options 可选参数: + * baike_num 返回百科信息的结果数,默认不返回 + */ + Json::Value advancedGeneral( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_advanced_general, null, data, null); + return result; + } + + /** + * 通用物体识别 + * 该请求用于通用物体及场景识别,即对于输入的一张图片(可正常解码,且长宽比适宜),输出图片中的多个物体及场景标签。 + * @param url 图片完整URL + * options 可选参数: + * baike_num 返回百科信息的结果数,默认不返回 + */ + Json::Value advancedGeneralUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_advanced_general, null, data, null); + return result; + } + + /** + * 地标识别 + * @param image 二进制图像数据 + * options 可选参数: + * baike_num 返回百科信息的结果数,默认不返回 + */ + Json::Value landmark_v1( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_landmark_v1, null, data, null); + return result; + } + + /** + * 地标识别 + * @param url 图片完整URL + * options 可选参数: + */ + Json::Value landmark_v1_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_landmark_v1, null, data, null); + return result; + } + + + /** + * 自定义红酒识别--入库 + * 接口使用文档: https://ai.baidu.com/ai-doc/IMAGERECOGNITION/skh4k58o4#%E8%87%AA%E5%AE%9A%E4%B9%89%E7%BA%A2%E9%85%92-%E5%85%A5%E5%BA%93 + */ + Json::Value redwine_add_v1_image(std::string const &image, std::string const &brief, + const std::map &options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["brief"] = brief; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = this->request(_redwine_add_v1, null, data, null); + return result; + } + + /** + * 自定义红酒识别--入库 + * 接口使用文档: https://ai.baidu.com/ai-doc/IMAGERECOGNITION/skh4k58o4#%E8%87%AA%E5%AE%9A%E4%B9%89%E7%BA%A2%E9%85%92-%E5%85%A5%E5%BA%93 + */ + Json::Value redwine_add_v1_url(std::string const &url, std::string const &brief, + const std::map &options) { + std::map data; + data["url"] = url; + data["brief"] = brief; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = this->request(_redwine_add_v1, null, data, null); + return result; + } + + /** + * 自定义红酒识别--检索 + * 接口使用文档: https://ai.baidu.com/ai-doc/IMAGERECOGNITION/skh4k58o4#%E8%87%AA%E5%AE%9A%E4%B9%89%E7%BA%A2%E9%85%92-%E6%A3%80%E7%B4%A2 + */ + Json::Value redwine_search_v1_image(std::string const &image, std::string const &custom_lib, + const std::map &options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["custom_lib"] = custom_lib; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = this->request(_redwine_search_v1, null, data, null); + return result; + } + + /** + * 自定义红酒识别--检索 + * 接口使用文档: https://ai.baidu.com/ai-doc/IMAGERECOGNITION/skh4k58o4#%E8%87%AA%E5%AE%9A%E4%B9%89%E7%BA%A2%E9%85%92-%E6%A3%80%E7%B4%A2 + */ + Json::Value redwine_search_v1_url(std::string const &url, std::string const &custom_lib, + const std::map &options) { + std::map data; + data["url"] = url; + data["custom_lib"] = custom_lib; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = this->request(_redwine_search_v1, null, data, null); + return result; + } + + /** + * 自定义红酒识别--删除 + * 接口使用文档: https://ai.baidu.com/ai-doc/IMAGERECOGNITION/skh4k58o4#%E8%87%AA%E5%AE%9A%E4%B9%89%E7%BA%A2%E9%85%92-%E5%88%A0%E9%99%A4 + */ + Json::Value redwine_delete_v1_image(std::string const &image, + const std::map &options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = this->request(_redwine_delete_v1, null, data, null); + return result; + } + + /** + * 自定义红酒识别--删除 + * 接口使用文档: https://ai.baidu.com/ai-doc/IMAGERECOGNITION/skh4k58o4#%E8%87%AA%E5%AE%9A%E4%B9%89%E7%BA%A2%E9%85%92-%E5%88%A0%E9%99%A4 + */ + Json::Value redwine_delete_v1_sign(std::string const &sign, + const std::map &options) { + std::map data; + data["cont_sign_list"] = sign; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = this->request(_redwine_delete_v1, null, data, null); + return result; + } + + /** + * 自定义红酒识别--更新 + * 接口使用文档: https://ai.baidu.com/ai-doc/IMAGERECOGNITION/skh4k58o4#%E8%87%AA%E5%AE%9A%E4%B9%89%E7%BA%A2%E9%85%92%E6%9B%B4%E6%96%B0 + */ + Json::Value redwine_update_v1_image(std::string const &image, std::string const &brief, + const std::map &options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["brief"] = brief; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = this->request(_redwine_update_v1, null, data, null); + return result; + } + + /** + * 自定义红酒识别--更新 + * 接口使用文档: https://ai.baidu.com/ai-doc/IMAGERECOGNITION/skh4k58o4#%E8%87%AA%E5%AE%9A%E4%B9%89%E7%BA%A2%E9%85%92%E6%9B%B4%E6%96%B0 + */ + Json::Value redwine_update_v1_url(std::string const &url, std::string const &brief, + const std::map &options) { + std::map data; + data["url"] = url; + data["brief"] = brief; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = this->request(_redwine_update_v1, null, data, null); + return result; + } + + /** + * 车辆属性识别 + * 接口使用文档: https://ai.baidu.com/ai-doc/VEHICLE/mk3hb3fde + */ + Json::Value vehicle_attr_classify_v2_image(std::string const &image, + const std::map &options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = this->request(_vehicle_attr_classify_v2, null, data, null); + return result; + } + + /** + * 车辆属性识别 + * 接口使用文档: https://ai.baidu.com/ai-doc/VEHICLE/mk3hb3fde + */ + Json::Value vehicle_attr_classify_v2_url(std::string const &url, + const std::map &options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = this->request(_vehicle_attr_classify_v2, null, data, null); + return result; + } + }; +} +#endif diff --git a/third/include/aip-cpp-sdk/image_process.h b/third/include/aip-cpp-sdk/image_process.h new file mode 100644 index 0000000..d249353 --- /dev/null +++ b/third/include/aip-cpp-sdk/image_process.h @@ -0,0 +1,693 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_IMAGEPROCESS_H__ +#define __AIP_IMAGEPROCESS_H__ + +#include "base/base.h" + +namespace aip { + + class Imageprocess: public AipBase + { + + public: + std::string _image_definition_enhance = + "https://aip.baidubce.com/rest/2.0/image-process/v1/image_definition_enhance"; + + std::string _sky_seg = + "https://aip.baidubce.com/rest/2.0/image-process/v1/sky_seg"; + + std::string _image_tyle_trans = + "https://aip.baidubce.com/rest/2.0/image-process/v1/style_trans"; + + std::string _selfie_anime = + "https://aip.baidubce.com/rest/2.0/image-process/v1/selfie_anime"; + + std::string _color_enhance = + "https://aip.baidubce.com/rest/2.0/image-process/v1/color_enhance"; + + std::string _image_inpainting = + "https://aip.baidubce.com/rest/2.0/image-process/v1/inpainting"; + + std::string _image_quality_enhance_v1 = + "https://aip.baidubce.com/rest/2.0/image-process/v1/image_quality_enhance"; + + std::string _contrast_enhance_v1 = + "https://aip.baidubce.com/rest/2.0/image-process/v1/contrast_enhance"; + + std::string _dehaze_v1 = + "https://aip.baidubce.com/rest/2.0/image-process/v1/dehaze"; + + std::string _colourize_v1 = + "https://aip.baidubce.com/rest/2.0/image-process/v1/colourize"; + + std::string _stretch_restore_v1 = + "https://aip.baidubce.com/rest/2.0/image-process/v1/stretch_restore"; + + std::string _remove_moire_v1 = "https://aip.baidubce.com/rest/2.0/image-process/v1/remove_moire"; + std::string _customize_stylization_v1 = + "https://aip.baidubce.com/rest/2.0/image-process/v1/customize_stylization"; + std::string _doc_repair_v1 = "https://aip.baidubce.com/rest/2.0/image-process/v1/doc_repair"; + std::string _denoise_v1 = "https://aip.baidubce.com/rest/2.0/image-process/v1/denoise"; + + Imageprocess(const std::string & app_id, const std::string & ak, const std::string & sk): + AipBase(app_id, ak, sk) + { + } + + /** + * 图像修复 + * 去除图片中不需要的遮挡物,并用背景内容填充,提高图像质量。 + * @param image 二进制图像数据 + * @param rectangle 要去除的位置为规则矩形时,给出坐标信息.每个元素包含left, top, width, height,int 类型 + * options 可选参数: + */ + Json::Value imageinpainting( + std::string const & image, + Json::Value & rectangle, + std::map options) + { + Json::Value data; + std::string access_token = this->getAccessToken(); + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["rectangle"] = rectangle; + + std::map< std::string,std::string >::iterator it ; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + std::string mid = "?access_token="; + std::string url = _image_inpainting+mid+access_token; + Json::Value result = + this->request_com(url, data); + + return result; + } + + /** + * 图像修复 + * 去除图片中不需要的遮挡物,并用背景内容填充,提高图像质量。 + * @param image 二进制图像数据 + * @param rectangle 要去除的位置为规则矩形时,给出坐标信息.每个元素包含left, top, width, height,int 类型 + * options 可选参数: + */ + Json::Value imageinpainting_url( + std::string const & url, + Json::Value & rectangle, + std::map options) + { + Json::Value data; + data["url"] = url; + data["rectangle"] = rectangle; + + std::map< std::string,std::string >::iterator it ; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + Json::Value result = + this->request_com(_image_inpainting, data); + + return result; + } + + /** + * 图像色彩增强 + * 可智能调节图片的色彩饱和度、亮度、对比度,使得图片内容细节、色彩更加逼真,可用于提升网站图片、手机相册图片、视频封面图片的质量 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value colorenhance( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_color_enhance, null, data, null); + + return result; + } + + /** + * 人像动漫化接口 + * 运用世界领先的对抗生成网络,结合人脸检测、头发分割、人像分割等技术,为用户量身定制千人千面的二次元动漫形象,并且可通过参数设置,生成戴口罩的二次元动漫人像 + * @param image 二进制图像数据 + * options 可选参数: + * type anime或者anime_mask。前者生成二次元动漫图,后者生成戴口罩的二次元动漫人像 + * mask_id 在type参数填入anime_mask时生效,1~8之间的整数,用于指定所使用的口罩的编码。type参数没有填入anime_mask,或mask_id 为空时,生成不戴口罩的二次元动漫图。 + */ + Json::Value selfieanime( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_selfie_anime, null, data, null); + + return result; + } + + /** + * 图像风格转换 + * 可将图像转化成卡通画、铅笔画、彩色铅笔画,或者哥特油画、彩色糖块油画、呐喊油画、神奈川冲浪里油画、奇异油画、薰衣草油画等共计9种风格,可用于开展趣味活动,或集成到美图应用中对图像进行风格转换 + * @param image 二进制图像数据 + * @param option 转换的风格 + * options 可选参数: + */ + Json::Value imagestyletrans( + std::string const & image, + std::string const & option, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["option"] = option; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_image_tyle_trans, null, data, null); + + return result; + } + + /** + * 天空分割 + * 可智能分割出天空边界位置,输出天空和其余背景的灰度图和二值图,可用于图像二次处理,进行天空替换、抠图等图片编辑场景。 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value skyseg( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_sky_seg, null, data, null); + + return result; + } + + /** + * 图像清晰增强 + * 对压缩后的模糊图像实现智能快速去噪,优化图像纹理细节,使画面更加自然清晰 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value imagedefinitionenhance( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_image_definition_enhance, null, data, null); + + return result; + } + + /** + * 图像风格转换 + * 图像风格转换 + * @param url 图片完整url + * @param option 转换的风格 + * options 可选参数: + */ + Json::Value imagestyletransurl( + std::string const & url, + std::string const & option, + const std::map & options) + { + std::map data; + + data["url"] = url; + data["option"] = option; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_image_tyle_trans, null, data, null); + + return result; + } + + /** + * 图像色彩增强 + * @param url 图片完整url + * options 可选参数: + */ + Json::Value colorenhanceurl( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_color_enhance, null, data, null); + + return result; + } + + /** + * 人像动漫化接口 + * 人像动漫化接口 + * @param url 图片完整url + * options 可选参数: + * type anime或者anime_mask。前者生成二次元动漫图,后者生成戴口罩的二次元动漫人像 + * mask_id 在type参数填入anime_mask时生效,1~8之间的整数,用于指定所使用的口罩的编码。type参数没有填入anime_mask,或mask_id 为空时,生成不戴口罩的二次元动漫图。 + */ + Json::Value selfieanimeurl( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_selfie_anime, null, data, null); + + return result; + } + + /** + * 天空分割 + * @param url 图片完整url + * options 可选参数: + */ + Json::Value skysegurl( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_sky_seg, null, data, null); + + return result; + } + + /** + * 图像清晰增强 + * @param url 图片完整url + * options 可选参数: + */ + Json::Value imagedefinitionenhanceurl( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_image_definition_enhance, null, data, null); + + return result; + } + + /** + * 图像无损放大 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value image_quality_enhance_v1( + std::string const & image, + const std::map & options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_image_quality_enhance_v1, null, data, null); + + return result; + } + + /** + * 图像无损放大 + * @param url 图片完整url + * options 可选参数: + */ + Json::Value image_quality_enhance_v1_url( + std::string const & url, + const std::map & options) + { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_image_quality_enhance_v1, null, data, null); + + return result; + } + + /** + * 图像对比度增强 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value contrast_enhance_v1( + std::string const & image, + const std::map & options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_contrast_enhance_v1, null, data, null); + + return result; + } + + /** + * 图像对比度增强 + * @param url 图片完整url + * options 可选参数: + */ + Json::Value contrast_enhance_v1_url( + std::string const & url, + const std::map & options) + { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_contrast_enhance_v1, null, data, null); + + return result; + } + + /** + * 图像去雾 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value dehaze_v1( + std::string const & image, + const std::map & options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_dehaze_v1, null, data, null); + + return result; + } + + /** + * 图像去雾 + * @param url 图片完整url + * options 可选参数: + */ + Json::Value dehaze_v1_url( + std::string const & url, + const std::map & options) + { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_dehaze_v1, null, data, null); + + return result; + } + + /** + * 黑白图像上色 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value colourize_v1( + std::string const & image, + const std::map & options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_colourize_v1, null, data, null); + + return result; + } + + /** + * 黑白图像上色 + * @param url 图片完整url + * options 可选参数: + */ + Json::Value colourize_v1_url( + std::string const & url, + const std::map & options) + { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_colourize_v1, null, data, null); + + return result; + } + + /** + * 拉伸图像恢复 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value stretch_restore_v1( + std::string const & image, + const std::map & options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_stretch_restore_v1, null, data, null); + + return result; + } + + /** + * 拉伸图像恢复 + * @param url 图片完整url + * options 可选参数: + */ + Json::Value stretch_restore_v1_url( + std::string const & url, + const std::map & options) + { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_stretch_restore_v1, null, data, null); + + return result; + } + + /** + * 图片去摩尔纹 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/IMAGEPROCESS/ql4wdlnc0 + */ + Json::Value remove_moire_v1( + std::string const & image, + const std::map &options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_remove_moire_v1, null, data, null); + + return result; + } + + /** + * 图片去摩尔纹 - url + * 接口使用文档链接: https://ai.baidu.com/ai-doc/IMAGEPROCESS/ql4wdlnc0 + */ + Json::Value remove_moire_v1_url( + std::string const & url, + const std::map &options) + { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_remove_moire_v1, null, data, null); + + return result; + } + + /** + * 图片去摩尔纹 - pdf + * 接口使用文档链接: https://ai.baidu.com/ai-doc/IMAGEPROCESS/ql4wdlnc0 + */ + Json::Value remove_moire_v1_pdf( + std::string const & pdf, + const std::map &options) + { + std::map data; + data["pdf_file"] = base64_encode(pdf.c_str(), (int) pdf.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_remove_moire_v1, null, data, null); + + return result; + } + + + /** + * 图像风格自定义 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/IMAGEPROCESS/al50vf6bq + */ + Json::Value customize_stylization_v1(std::string const & image, Json::Value & options) + { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_customize_stylization_v1, data, &headers); + + return result; + } + + /** + * 图像风格自定义 - url + * 接口使用文档链接: https://ai.baidu.com/ai-doc/IMAGEPROCESS/al50vf6bq + */ + Json::Value customize_stylization_v1_url(std::string const & url, Json::Value & options) + { + Json::Value data; + data["url"] = url; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_customize_stylization_v1, data, &headers); + + return result; + } + + /** + * 文档图片去底纹 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/IMAGEPROCESS/Nl6os53ab + */ + Json::Value doc_repair_v1( + std::string const & image, + const std::map &options) + { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_doc_repair_v1, null, data, null); + + return result; + } + + /** + * 文档图片去底纹 - url + * 接口使用文档链接: https://ai.baidu.com/ai-doc/IMAGEPROCESS/Nl6os53ab + */ + Json::Value doc_repair_v1_url( + std::string const &url, + const std::map &options) + { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_doc_repair_v1, null, data, null); + + return result; + } + + /** + * 图像去噪 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/IMAGEPROCESS/Tl78sby7g + */ + Json::Value denoise_v1( + std::string const & image, int option) + { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["option"] = option; + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_denoise_v1, data, &headers); + + return result; + } + + /** + * 图像去噪 - url + * 接口使用文档链接: https://ai.baidu.com/ai-doc/IMAGEPROCESS/Tl78sby7g + */ + Json::Value denoise_v1_url( + std::string const &url, int option) + { + Json::Value data; + data["url"] = url; + data["option"] = option; + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_denoise_v1, data, &headers); + + return result; + } + + }; +} +#endif diff --git a/third/include/aip-cpp-sdk/image_search.h b/third/include/aip-cpp-sdk/image_search.h new file mode 100644 index 0000000..f1749c2 --- /dev/null +++ b/third/include/aip-cpp-sdk/image_search.h @@ -0,0 +1,1001 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_IMAGESEARCH_H__ +#define __AIP_IMAGESEARCH_H__ + +#include "base/base.h" + +namespace aip { + + class Imagesearch: public AipBase + { + public: + + std::string _picturebook_search = + "https://aip.baidubce.com/rest/2.0/imagesearch/v1/realtime_search/picturebook/search"; + + std::string _picturebook_update = + "https://aip.baidubce.com/rest/2.0/imagesearch/v1/realtime_search/picturebook/update"; + + std::string _picturebook_add = + "https://aip.baidubce.com/rest/2.0/imagesearch/v1/realtime_search/picturebook/add"; + + std::string _picturebook_delete = + "https://aip.baidubce.com/rest/2.0/imagesearch/v1/realtime_search/picturebook/delete"; + + std::string _same_hq_add = + "https://aip.baidubce.com/rest/2.0/realtime_search/same_hq/add"; + + std::string _same_hq_search = + "https://aip.baidubce.com/rest/2.0/realtime_search/same_hq/search"; + + std::string _same_hq_delete = + "https://aip.baidubce.com/rest/2.0/realtime_search/same_hq/delete"; + + std::string _same_hq_update = + "https://aip.baidubce.com/rest/2.0/realtime_search/same_hq/update"; + + std::string _similar_add = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/similar/add"; + + std::string _similar_search = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/similar/search"; + + std::string _similar_delete = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/similar/delete"; + + std::string _similar_update = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/similar/update"; + + std::string _product_add = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/product/add"; + + std::string _product_search = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/product/search"; + + std::string _product_delete = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/product/delete"; + + std::string _product_update = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/product/update"; + + std::string _materiel_add = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/materiel/add"; + + std::string _materiel_search = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/materiel/search"; + + std::string _materiel_delete = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/materiel/delete"; + + std::string _materiel_update = + "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/materiel/update"; + + Imagesearch(const std::string & app_id, const std::string & ak, const std::string & sk): AipBase(app_id, ak, sk) + { + } + + /** + * 面料入库 + * materiel_add + * 文档参考:https://ai.baidu.com/ai-doc/IMAGESEARCH/kl6xkl6kq#%E9%9D%A2%E6%96%99%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E5%85%A5%E5%BA%93 + */ + Json::Value materiel_add( + std::string const & image, + std::string const & brief, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["brief"] = brief; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_materiel_add, null, data, null); + + return result; + } + + /** + * 面料入库 + * materiel_add + * 文档参考:https://ai.baidu.com/ai-doc/IMAGESEARCH/kl6xkl6kq#%E9%9D%A2%E6%96%99%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E5%85%A5%E5%BA%93 + */ + Json::Value materiel_add_url( + std::string const & url, + std::string const & brief, + const std::map & options) + { + std::map data; + + data["url"] = url; + data["brief"] = brief; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_materiel_add, null, data, null); + + return result; + } + + /** + * 面料检索 + * materiel_search + * 文档参考:https://ai.baidu.com/ai-doc/IMAGESEARCH/kl6xkl6kq#%E9%9D%A2%E6%96%99%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%A3%80%E7%B4%A2 + */ + Json::Value materiel_search( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_materiel_search, null, data, null); + + return result; + } + + /** + * 面料检索 + * materiel_search + * 文档参考:https://ai.baidu.com/ai-doc/IMAGESEARCH/kl6xkl6kq#%E9%9D%A2%E6%96%99%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%A3%80%E7%B4%A2 + */ + Json::Value materiel_search_url( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_materiel_search, null, data, null); + + return result; + } + + /** + * 面料删除 + * materiel_delete_by_image + * 参考文档:https://ai.baidu.com/ai-doc/IMAGESEARCH/kl6xkl6kq#%E9%9D%A2%E6%96%99%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E5%88%A0%E9%99%A4 + */ + Json::Value materiel_delete_by_image( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_materiel_delete, null, data, null); + + return result; + } + + /** + * 面料删除 + * materiel_delete_by_url + * 参考文档:https://ai.baidu.com/ai-doc/IMAGESEARCH/kl6xkl6kq#%E9%9D%A2%E6%96%99%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E5%88%A0%E9%99%A4 + */ + Json::Value materiel_delete_by_url( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_materiel_delete, null, data, null); + + return result; + } + + /** + * 面料图片删除 + * materiel_delete_by_sign + * 参考文档:https://ai.baidu.com/ai-doc/IMAGESEARCH/kl6xkl6kq#%E9%9D%A2%E6%96%99%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E5%88%A0%E9%99%A4 + */ + Json::Value materiel_delete_by_sign( + std::string const & cont_sign, + const std::map & options) + { + std::map data; + + data["cont_sign"] = cont_sign; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_materiel_delete, null, data, null); + + return result; + } + + /** + * 面料更新 + * 参考文档:https://ai.baidu.com/ai-doc/IMAGESEARCH/kl6xkl6kq#%E9%9D%A2%E6%96%99%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value materiel_update( + std::string const & image, + std::string const & brief, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["brief"] = brief; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_materiel_update, null, data, null); + + return result; + } + + /** + * 面料更新 + * 参考文档:https://ai.baidu.com/ai-doc/IMAGESEARCH/kl6xkl6kq#%E9%9D%A2%E6%96%99%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value materiel_update_url( + std::string const & url, + std::string const & brief, + const std::map & options) + { + std::map data; + + data["url"] = url; + data["brief"] = brief; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_materiel_update, null, data, null); + + return result; + } + + /** + * 面料更新 + * 参考文档:https://ai.baidu.com/ai-doc/IMAGESEARCH/kl6xkl6kq#%E9%9D%A2%E6%96%99%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value materiel_update_cont_sign( + std::string const & cont_sign, + std::string const & brief, + const std::map & options) + { + std::map data; + + data["cont_sign"] = cont_sign; + data["brief"] = brief; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_materiel_update, null, data, null); + + return result; + } + + /** + * same_hq_add + * 该请求用于实时检索相同图片集合。即对于输入的一张图片(可正常解码,且长宽比适宜),返回自建图库中相同的图片集合。相同图检索包含入库、检索、删除三个子接口;**在正式使用之前请加入QQ群:649285136 联系工作人员完成建库并调用入库接口完成图片入库**。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * brief 检索时原样带回,最长256B。 + */ + Json::Value same_hq_add( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_same_hq_add, null, data, null); + + return result; + } + + /** + * same_hq_search + * 使用该接口前,请加入QQ群:649285136 ,联系工作人员完成建库。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value same_hq_search( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_same_hq_search, null, data, null); + + return result; + } + + /** + * same_hq_delete_by_image + * 删除相同图 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value same_hq_delete_by_image( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_same_hq_delete, null, data, null); + + return result; + } + + /** + * same_hq_delete_by_sign + * 删除相同图 + * @param cont_sign 图片签名(和image二选一,image优先级更高) + * options 可选参数: + */ + Json::Value same_hq_delete_by_sign( + std::string const & cont_sign, + const std::map & options) + { + std::map data; + + data["cont_sign"] = cont_sign; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_same_hq_delete, null, data, null); + + return result; + } + + /** + * 更新相同图 + * 使用文档链接: https://ai.baidu.com/ai-doc/IMAGESEARCH/Ck3bczreq#%E7%9B%B8%E5%90%8C%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value same_hq_update_by_image( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_same_hq_update, null, data, null); + + return result; + } + + /** + * 更新相同图 + * 使用文档链接: https://ai.baidu.com/ai-doc/IMAGESEARCH/Ck3bczreq#%E7%9B%B8%E5%90%8C%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value same_hq_update_by_url( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_same_hq_update, null, data, null); + + return result; + } + + /** + * 更新相同图 + * 使用文档链接: https://ai.baidu.com/ai-doc/IMAGESEARCH/Ck3bczreq#%E7%9B%B8%E5%90%8C%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value same_hq_update_by_sign( + std::string const & cont_sign, + const std::map & options) + { + std::map data; + + data["cont_sign"] = cont_sign; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_same_hq_update, null, data, null); + + return result; + } + + /** + * similar_add + * 该请求用于实时检索相似图片集合。即对于输入的一张图片(可正常解码,且长宽比适宜),返回自建图库中相似的图片集合。相似图检索包含入库、检索、删除三个子接口;**在正式使用之前请加入QQ群:649285136 联系工作人员完成建库并调用入库接口完成图片入库**。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * brief 检索时原样带回,最长256B。 + */ + Json::Value similar_add( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_similar_add, null, data, null); + + return result; + } + + /** + * similar_search + * 使用该接口前,请加入QQ群:649285136 ,联系工作人员完成建库。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value similar_search( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_similar_search, null, data, null); + + return result; + } + + /** + * similar_delete_by_image + * 删除相似图 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value similar_delete_by_image( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_similar_delete, null, data, null); + + return result; + } + + /** + * similar_delete_by_sign + * 删除相似图 + * @param cont_sign 图片签名(和image二选一,image优先级更高) + * options 可选参数: + */ + Json::Value similar_delete_by_sign( + std::string const & cont_sign, + const std::map & options) + { + std::map data; + + data["cont_sign"] = cont_sign; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_similar_delete, null, data, null); + + return result; + } + + /** + * 更新相似图 + * 使用文档链接: https://ai.baidu.com/ai-doc/IMAGESEARCH/3k3bczqz8#%E7%9B%B8%E4%BC%BC%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value similar_update_by_image( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_similar_update, null, data, null); + + return result; + } + + /** + * 更新相似图 + * 使用文档链接: https://ai.baidu.com/ai-doc/IMAGESEARCH/3k3bczqz8#%E7%9B%B8%E4%BC%BC%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value similar_update_by_url( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_similar_update, null, data, null); + + return result; + } + + /** + * 更新相似图 + * 使用文档链接: https://ai.baidu.com/ai-doc/IMAGESEARCH/3k3bczqz8#%E7%9B%B8%E4%BC%BC%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value similar_update_by_sign( + std::string const & cont_sign, + const std::map & options) + { + std::map data; + + data["cont_sign"] = cont_sign; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_similar_update, null, data, null); + + return result; + } + + /** + * product_add + * 1、该请求用于实时检索商品类型图片相同或相似的图片集合,适用于电商平台或商品展示等场景,即对于输入的一张图片(可正常解码,且长宽比适宜),返回自建商品库中相同或相似的图片集合。 +2、商品检索包含入库、检索、删除三个子接口;**在正式使用之前请在[控制台](https://console.bce.baidu.com/ai/#/ai/imagesearch/overview/index "控制台")创建应用后,在应用详情页申请建库,建库成功后方可正常使用入库、检索、删除三个接口**。 + + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * brief 检索时原样带回,最长256B。**请注意,检索接口不返回原图,仅反馈当前填写的brief信息,所以调用该入库接口时,brief信息请尽量填写可关联至本地图库的图片id或者图片url、图片名称等信息** + * class_id1 商品分类维度1,支持1-60范围内的整数。检索时可圈定该分类维度进行检索 + * class_id2 商品分类维度1,支持1-60范围内的整数。检索时可圈定该分类维度进行检索 + */ + Json::Value product_add( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_product_add, null, data, null); + + return result; + } + + /** + * product_search + * 完成入库后,可使用该接口实现商品检索。 +**请注意,检索接口不返回原图,仅反馈当前填写的brief信息,请调用入库接口时尽量填写可关联至本地图库的图片id或者图片url等信息** + + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * class_id1 商品分类维度1,支持1-60范围内的整数。检索时可圈定该分类维度进行检索 + * class_id2 商品分类维度1,支持1-60范围内的整数。检索时可圈定该分类维度进行检索 + */ + Json::Value product_search( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_product_search, null, data, null); + + return result; + } + + /** + * product_delete_by_image + * 删除商品 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value product_delete_by_image( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_product_delete, null, data, null); + + return result; + } + + /** + * product_delete_by_sign + * 删除商品 + * @param cont_sign 图片签名(和image二选一,image优先级更高) + * options 可选参数: + */ + Json::Value product_delete_by_sign( + std::string const & cont_sign, + const std::map & options) + { + std::map data; + + data["cont_sign"] = cont_sign; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_product_delete, null, data, null); + + return result; + } + + /** + * 更新商品搜索 + * 使用文档链接: https://ai.baidu.com/ai-doc/IMAGESEARCH/Dk3bczrmj#%E5%95%86%E5%93%81%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value product_update_by_image( + std::string const & image, + const Json::Value & options) + { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_product_update, data, &headers); + + return result; + } + + /** + * 更新商品搜索 + * 使用文档链接: https://ai.baidu.com/ai-doc/IMAGESEARCH/Dk3bczrmj#%E5%95%86%E5%93%81%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value product_update_by_url( + std::string const & url, + const Json::Value & options) + { + Json::Value data; + data["url"] = url; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_product_update, data, &headers); + + return result; + } + + /** + * 更新商品搜索 + * 使用文档链接: https://ai.baidu.com/ai-doc/IMAGESEARCH/Dk3bczrmj#%E5%95%86%E5%93%81%E5%9B%BE%E7%89%87%E6%90%9C%E7%B4%A2%E6%9B%B4%E6%96%B0 + */ + Json::Value product_update_by_sign( + std::string const & cont_sign, + const Json::Value & options) + { + Json::Value data; + data["cont_sign"] = cont_sign; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_product_update, data, &headers); + + return result; + } + + /** + * 绘本入库 + * 该接口实现单张图片入库,入库时需要同步提交图片及可关联至本地图库的摘要信息 + * @param url 图像url + * @param brief 检索时原样带回 + * options 可选参数: + * tags tag间以逗号分隔,最多2个tag,2个tag无层级关系,检索时支持逻辑运算 + */ + Json::Value picturebook_add_url( + std::string const & url, + std::string const & brief, + const std::map & options) + { + std::map data; + + data["url"] = url; + data["brief"] = brief; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_picturebook_add, null, data, null); + + return result; + } + + /** + * 绘本更新 + * 绘本图片更新 + * @param image 二进制图像数据 + * options 可选参数: + * tags tag间以逗号分隔,最多2个tag,2个tag无层级关系,检索时支持逻辑运算 + * brief 更新的摘要信息,最长256B。样例:{"name" + */ + Json::Value picturebook_update( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_picturebook_update, null, data, null); + + return result; + } + + /** + * 绘本图片删除 + * 删除图库中的图片 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value picturebook_delete( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_picturebook_delete, null, data, null); + + return result; + } + + /** + * 绘本图片删除 + * 完成入库后,可使用该接口实现绘本图删除 + * @param url 图片url + * options 可选参数: + */ + Json::Value picturebook_delete_url( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_picturebook_delete, null, data, null); + + return result; + } + + /** + * 绘本更新 + * 绘本图片更新,通过图片签名 + * @param cont_sign 图片签名 + * options 可选参数: + * tags tag间以逗号分隔,最多2个tag,2个tag无层级关系,检索时支持逻辑运算 + * brief 更新的摘要信息,最长256B。样例:{"name" + */ + Json::Value picturebook_update_cont_sign( + std::string const & cont_sign, + const std::map & options) + { + std::map data; + + data["cont_sign"] = cont_sign; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_picturebook_update, null, data, null); + + return result; + } + + /** + * 本图片搜索 + * 该接口实现单张图片入库,入库时需要同步提交图片及可关联至本地图库的摘要信息 + * @param image 二进制图像数据 + * options 可选参数: + * tags tag间以逗号分隔,最多2个tag,2个tag无层级关系,检索时支持逻辑运算 + * tag_logic 检索时tag之间的逻辑, 0:逻辑and,1:逻辑or + */ + Json::Value picturebook_search( + std::string const & image, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_picturebook_search, null, data, null); + + return result; + } + + /** + * 绘本图片搜索 + * 完成入库后,可使用该接口实现绘本图检索 + * @param url 图片url + * options 可选参数: + * tags tag间以逗号分隔,最多2个tag,2个tag无层级关系,检索时支持逻辑运算 + * tag_logic 检索时tag之间的逻辑, 0:逻辑and,1:逻辑or + */ + Json::Value picturebook_search_url( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_picturebook_search, null, data, null); + + return result; + } + + /** + * 绘本更新 + * 绘本图片更新 + * @param url 图片url + * options 可选参数: + * tags tag间以逗号分隔,最多2个tag,2个tag无层级关系,检索时支持逻辑运算 + * brief 更新的摘要信息,最长256B。样例:{"name" + */ + Json::Value picturebook_update_url( + std::string const & url, + const std::map & options) + { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_picturebook_update, null, data, null); + + return result; + } + + /** + * 绘本图片删除 + * 完成入库后,可使用该接口实现绘本图检索 + * @param cont_sign 图片签名 + * options 可选参数: + */ + Json::Value picturebook_delete_cont_sign( + std::string const & cont_sign, + const std::map & options) + { + std::map data; + + data["cont_sign"] = cont_sign; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_picturebook_delete, null, data, null); + + return result; + } + + /** + * 绘本入库 + * 该接口实现单张图片入库,入库时需要同步提交图片及可关联至本地图库的摘要信息 + * @param image 二进制图像数据 + * @param brief 检索时原样带回 + * options 可选参数: + * tags tag间以逗号分隔,最多2个tag,2个tag无层级关系,检索时支持逻辑运算 + */ + Json::Value picturebook_add( + std::string const & image, + std::string const & brief, + const std::map & options) + { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["brief"] = brief; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_picturebook_add, null, data, null); + + return result; + } + + + }; +} +#endif diff --git a/third/include/aip-cpp-sdk/kg.h b/third/include/aip-cpp-sdk/kg.h new file mode 100644 index 0000000..3ab2dc1 --- /dev/null +++ b/third/include/aip-cpp-sdk/kg.h @@ -0,0 +1,206 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_KG_H__ +#define __AIP_KG_H__ + +#include "base/base.h" + +namespace aip { + + class Kg: public AipBase + { + public: + + + std::string _create_task = + "https://aip.baidubce.com/rest/2.0/kg/v1/pie/task_create"; + + std::string _update_task = + "https://aip.baidubce.com/rest/2.0/kg/v1/pie/task_update"; + + std::string _task_info = + "https://aip.baidubce.com/rest/2.0/kg/v1/pie/task_info"; + + std::string _task_query = + "https://aip.baidubce.com/rest/2.0/kg/v1/pie/task_query"; + + std::string _task_start = + "https://aip.baidubce.com/rest/2.0/kg/v1/pie/task_start"; + + std::string _task_status = + "https://aip.baidubce.com/rest/2.0/kg/v1/pie/task_status"; + + + Kg(const std::string & app_id, const std::string & ak, const std::string & sk): AipBase(app_id, ak, sk) + { + } + + /** + * create_task + * 创建一个新的信息抽取任务 + * @param name 任务名字 + * @param template_content json string 解析模板内容 + * @param input_mapping_file 抓取结果映射文件的路径 + * @param url_pattern url pattern + * @param output_file 输出文件名字 + * options 可选参数: + * limit_count 限制解析数量limit_count为0时进行全量任务,limit_count>0时只解析limit_count数量的页面 + */ + Json::Value create_task( + std::string const & name, + std::string const & template_content, + std::string const & input_mapping_file, + std::string const & url_pattern, + std::string const & output_file, + const std::map & options) + { + std::map data; + + data["name"] = name; + data["template_content"] = template_content; + data["input_mapping_file"] = input_mapping_file; + data["url_pattern"] = url_pattern; + data["output_file"] = output_file; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_create_task, null, data, null); + + return result; + } + + /** + * update_task + * 更新任务配置,在任务重新启动后生效 + * @param id 任务ID + * options 可选参数: + * name 任务名字 + * template_content json string 解析模板内容 + * input_mapping_file 抓取结果映射文件的路径 + * url_pattern url pattern + * output_file 输出文件名字 + */ + Json::Value update_task( + const int & id, + const std::map & options) + { + std::map data; + + data["id"] = std::to_string(id); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_update_task, null, data, null); + + return result; + } + + /** + * task_info + * 根据任务id获取单个任务的详细信息 + * @param id 任务ID + * options 可选参数: + */ + Json::Value task_info( + const int & id, + const std::map & options) + { + std::map data; + + data["id"] = std::to_string(id); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_task_info, null, data, null); + + return result; + } + + /** + * task_query + * 该请求用于菜品识别。即对于输入的一张图片(可正常解码,且长宽比适宜),输出图片的菜品名称、卡路里信息、置信度。 + * options 可选参数: + * id 任务ID,精确匹配 + * name 中缀模糊匹配,abc可以匹配abc,aaabc,abcde等 + * status 要筛选的任务状态 + * page 页码 + * per_page 页码 + */ + Json::Value task_query( + const std::map & options) + { + std::map data; + + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_task_query, null, data, null); + + return result; + } + + /** + * task_start + * 启动一个已经创建的信息抽取任务 + * @param id 任务ID + * options 可选参数: + */ + Json::Value task_start( + const int & id, + const std::map & options) + { + std::map data; + + data["id"] = std::to_string(id); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_task_start, null, data, null); + + return result; + } + + /** + * task_status + * 查询指定的任务的最新执行状态 + * @param id 任务ID + * options 可选参数: + */ + Json::Value task_status( + const int & id, + const std::map & options) + { + std::map data; + + data["id"] = std::to_string(id); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_task_status, null, data, null); + + return result; + } + + + }; +} +#endif \ No newline at end of file diff --git a/third/include/aip-cpp-sdk/machine_translation.h b/third/include/aip-cpp-sdk/machine_translation.h new file mode 100644 index 0000000..b0a67a5 --- /dev/null +++ b/third/include/aip-cpp-sdk/machine_translation.h @@ -0,0 +1,144 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_MACHINE_TRANSLATION_H__ +#define __AIP_MACHINE_TRANSLATION_H__ + +#include "base/base.h" + +namespace aip { + + class Machinetranslation : public AipBase + { + public: + + std::string _pictrans_v1 = + "https://aip.baidubce.com/file/2.0/mt/pictrans/v1"; + std::string _texttrans_v1 = + "https://aip.baidubce.com/rpc/2.0/mt/texttrans/v1"; + std::string _texttrans_with_dict_v1 = + "https://aip.baidubce.com/rpc/2.0/mt/texttrans-with-dict/v1"; + std::string _doc_translation_create_v2 = + "https://aip.baidubce.com/rpc/2.0/mt/v2/doc-translation/create"; + std::string _doc_translation_query_v2 = + "https://aip.baidubce.com/rpc/2.0/mt/v2/doc-translation/query"; + std::string _speech_translation_v2 = + "https://aip.baidubce.com/rpc/2.0/mt/v2/speech-translation"; + + Machinetranslation(const std::string & app_id, const std::string & ak, const std::string & sk) + : AipBase(app_id, ak, sk) + { + } + + /** + * 文本翻译-通用版 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/MT/4kqryjku9 + */ + Json::Value texttrans_v1( + std::string const &from, + std::string const &to, + std::string const &q, + const Json::Value & options) + { + Json::Value data; + data["from"] = from; + data["to"] = to; + data["q"] = q; + merge_json(data, options); + + Json::Value result = this->request_com(_texttrans_v1, data); + + return result; + } + + /** + * 文本翻译-词典版 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/MT/nkqrzmbpc + */ + Json::Value texttrans_with_dict_v1( + std::string const &from, + std::string const &to, + std::string const &q, + const Json::Value & options) + { + Json::Value data; + data["from"] = from; + data["to"] = to; + data["q"] = q; + merge_json(data, options); + + Json::Value result = this->request_com(_texttrans_with_dict_v1, data); + + return result; + } + + /** + * 文档翻译 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/MT/Xky9x5xub + */ + Json::Value doc_translation_create_v2( + std::string const &from, + std::string const &to, + Json::Value & options) + { + Json::Value data; + data["from"] = from; + data["to"] = to; + merge_json(data, options); + + Json::Value result = + this->request_com(_doc_translation_create_v2, data); + + return result; + } + + /** + * 文档翻译-文档状态查询 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/MT/Xky9x5xub + */ + Json::Value doc_translation_query_v2( + std::string const &id) + { + Json::Value data; + data["id"] = id; + Json::Value result = this->request_com(_doc_translation_query_v2, data); + + return result; + } + + /** + * 语音翻译 + * 接口使用文档链接: https://ai.baidu.com/ai-doc/MT/el4cmi76f + */ + Json::Value speech_translation_v2( + std::string const &from, + std::string const &to, + std::string const &voice, + std::string const &format) + { + Json::Value data; + data["from"] = from; + data["to"] = to; + data["voice"] = base64_encode(voice.c_str(), (int) voice.size()); + data["format"] = format; + Json::Value result = + this->request_com(_speech_translation_v2, data); + + return result; + } + + }; +} +#endif \ No newline at end of file diff --git a/third/include/aip-cpp-sdk/nlp.h b/third/include/aip-cpp-sdk/nlp.h new file mode 100644 index 0000000..3cd40c6 --- /dev/null +++ b/third/include/aip-cpp-sdk/nlp.h @@ -0,0 +1,1034 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_NLP_H__ +#define __AIP_NLP_H__ + +#include "base/base.h" + +namespace aip { + + class Nlp: public AipBase + { + public: + + std::string _lexer = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/lexer"; + + std::string _wordembedding = + "https://aip.baidubce.com/rpc/2.0/nlp/v2/word_emb_vec"; + + std::string _depparser = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/depparser"; + + std::string _dnnlm_cn = + "https://aip.baidubce.com/rpc/2.0/nlp/v2/dnnlm_cn"; + + std::string _word_sim_embedding = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/word_emb_sim"; + + std::string _simnet = + "https://aip.baidubce.com/rpc/2.0/nlp/v2/simnet"; + + std::string _comment_tag = + "https://aip.baidubce.com/rpc/2.0/nlp/v2/comment_tag"; + + std::string _word_emb_sim_v2 = + "https://aip.baidubce.com/rpc/2.0/nlp/v2/word_emb_sim"; + + std::string _sentiment_classify_v1 = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/sentiment_classify"; + + std::string _lexer_custom_v1 = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/lexer_custom"; + + std::string _keyword_v1 = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/keyword"; + + std::string _topic_v1 = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/topic"; + + std::string _ecnet_v1 = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/ecnet"; + + std::string _emotion_v1 = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/emotion"; + + std::string _news_summary_v1 = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/news_summary"; + + std::string _address_v1 = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/address"; + + std::string _comment_tag_custom = + "https://aip.baidubce.com/rpc/2.0/nlp/v2/comment_tag_custom"; + + std::string _sentiment_classify_custom = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/sentiment_classify_custom"; + + std::string _couplets = + "https://aip.baidubce.com/rpc/2.0/creation/v1/couplets"; + + std::string _poem = + "https://aip.baidubce.com/rpc/2.0/creation/v1/poem"; + + std::string _entity_level_sentiment = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/entity_level_sentiment"; + + std::string _entity_level_sentiment_add = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/entity_level_sentiment/add"; + + std::string _entity_level_sentiment_delete = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/entity_level_sentiment/delete"; + + std::string _entity_level_sentiment_delete_repo = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/entity_level_sentiment/delete_repo"; + + std::string _entity_level_sentiment_list = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/entity_level_sentiment/list"; + + std::string _entity_level_sentiment_query = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/entity_level_sentiment/query"; + + std::string _topic_phrase = + "https://aip.baidubce.com/rpc/2.0/creation/v1/topic_phrase"; + + std::string _recruitment_cvparser = + "https://aip.baidubce.com/rpc/2.0/recruitment/v1/cvparser"; + + std::string _recruitment_person_post = + "https://aip.baidubce.com/rpc/2.0/recruitment/v1/person_post"; + + std::string _recruitment_personas = + "https://aip.baidubce.com/rpc/2.0/recruitment/v1/personas"; + + std::string _titlepredictor = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/titlepredictor"; + + std::string _depparser_v2 = + "https://aip.baidubce.com/rpc/2.0/nlp/v2/depparser"; + + std::string _bless_creation = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/bless_creation"; + + std::string _entity_analysis = + "https://aip.baidubce.com/rpc/2.0/nlp/v1/entity_analysis"; + + std::string _text_correction = + "https://aip.baidubce.com/rpc/2.0/nlp/v2/text_correction"; + + Nlp(const std::string & app_id, const std::string & ak, const std::string & sk): AipBase(app_id, ak, sk) + { + } + + + /** + * lexer + * 词法分析接口向用户提供分词、词性标注、专名识别三大功能;能够识别出文本串中的基本词汇(分词),对这些词汇进行重组、标注组合后词汇的词性,并进一步识别出命名实体。 + * @param text 待分析文本(目前仅支持UTF8编码),长度不超过65536字节 + * options 可选参数: + */ + Json::Value lexer( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it=options.begin(); it!=options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_lexer, null, data.toStyledString(), null); + + return result; + } + + /** + * wordembedding + * 词向量表示接口提供中文词向量的查询功能。 + * @param word 文本内容(UTF8编码),最大64字节 + * options 可选参数: + */ + Json::Value wordembedding( + std::string const & word, + const std::map & options) + { + Json::Value data; + + data["word"] = word; + + std::map::const_iterator it; + for(it=options.begin(); it!=options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_wordembedding, null, data.toStyledString(), null); + + return result; + } + + /** + * depparser + * 词向量表示接口提供中文词向量的查询功能。 + * @param text 待分析文本(目前仅支持UTF8编码),长度不超过256字节 + * options 可选参数: + * mode 模型选择。默认值为0,可选值mode=0(对应web模型);mode=1(对应query模型) + */ + Json::Value depparser( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it=options.begin(); it!=options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_depparser, null, data.toStyledString(), null); + + return result; + } + + /** + * dnnlm_cn + * 中文DNN语言模型接口用于输出切词结果并给出每个词在句子中的概率值,判断一句话是否符合语言表达习惯。 + * @param text 文本内容(UTF8编码),最大10240字节,不需要切词 + * options 可选参数: + */ + Json::Value dnnlm_cn( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it=options.begin(); it!=options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_dnnlm_cn, null, data.toStyledString(), null); + + return result; + } + + /** + * word_sim_embedding + * 输入两个词,得到两个词的相似度结果。 + * @param word_1 词1(UTF8编码),最大64字节 + * @param word_2 词1(UTF8编码),最大64字节 + * options 可选参数: + * mode 预留字段,可选择不同的词义相似度模型。默认值为0,目前仅支持mode=0 + */ + Json::Value word_sim_embedding( + std::string const & word_1, + std::string const & word_2, + const std::map & options) + { + Json::Value data; + + data["word_1"] = word_1; + data["word_2"] = word_2; + + std::map::const_iterator it; + for(it=options.begin(); it!=options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_word_sim_embedding, null, data.toStyledString(), null); + + return result; + } + + /** + * simnet + * 短文本相似度接口用来判断两个文本的相似度得分。 + * @param text_1 待比较文本1(UTF8编码),最大512字节 + * @param text_2 待比较文本2(UTF8编码),最大512字节 + * options 可选参数: + * model 默认为"BOW",可选"BOW"、"CNN"与"GRNN" + */ + Json::Value simnet( + std::string const & text_1, + std::string const & text_2, + const std::map & options) + { + Json::Value data; + + data["text_1"] = text_1; + data["text_2"] = text_2; + + std::map::const_iterator it; + for(it=options.begin(); it!=options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_simnet, null, data.toStyledString(), null); + + return result; + } + + /** + * comment_tag + * 评论观点抽取接口用来提取一条评论句子的关注点和评论观点,并输出评论观点标签及评论观点极性。 + * @param text 评论内容(UTF8编码),最大10240字节 + * options 可选参数: + * type 评论行业类型,默认为4(餐饮美食) + */ + Json::Value comment_tag( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it=options.begin(); it!=options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_comment_tag, null, data.toStyledString(), null); + + return result; + } + + /** + * 词义相似度 + * @param word_1 词1,最大64字节 + * @param word_2 词2,最大64字节 + * options 可选参数: + */ + Json::Value word_emb_sim_v2( + std::string const & word_1, + std::string const & word_2, + const std::map & options) + { + Json::Value data; + + data["word_1"] = word_1; + data["word_2"] = word_2; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + // 使用GBK编码 + std::map headers; + headers["charset"] = "GBK"; + + Json::Value result = + this->request(_word_emb_sim_v2, null, data.toStyledString(), headers); + + return result; + } + + /** + * 情感倾向分析 + * @param text 文本内容,最大2048字节 + * options 可选参数: + */ + Json::Value sentiment_classify_v1( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_sentiment_classify_v1, null, data.toStyledString(), null); + + return result; + } + + /** + * 词法分析(定制) + * @param text 待分析文本,长度不超过20000字节 + * options 可选参数: + */ + Json::Value lexer_custom_v1( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_lexer_custom_v1, null, data.toStyledString(), null); + + return result; + } + + /** + * 文章标签 + * @param title 文章标题,最大80字节 + * @param content 文章内容,最大65535字节 + * options 可选参数: + */ + Json::Value keyword_v1( + std::string const & title, + std::string const & content, + const std::map & options) + { + Json::Value data; + + data["title"] = title; + data["content"] = content; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_keyword_v1, null, data.toStyledString(), null); + + return result; + } + + /** + * 文章分类 + * @param title 文章标题,最大80字节 + * @param content 文章内容,最大65535字节 + * options 可选参数: + */ + Json::Value topic_v1( + std::string const & title, + std::string const & content, + const std::map & options) + { + Json::Value data; + + data["title"] = title; + data["content"] = content; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_topic_v1, null, data.toStyledString(), null); + + return result; + } + + /** + * 文本纠错 + * @param text 待纠错文本,输入限制550个汉字或英文 + * options 可选参数: + */ + Json::Value ecnet_v1( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_ecnet_v1, null, data.toStyledString(), null); + + return result; + } + + /** + * 文本纠错-高级版 + * @param text 待纠错文本,输入限制550个汉字或英文 + * options 可选参数: + */ + Json::Value text_correction( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_text_correction, null, data.toStyledString(), null); + + return result; + } + + /** + * 对话情绪识别接口 + * @param text 待识别情感文本,输入限制512字节 + * options 可选参数: + * - scene 场景 + * default(默认项-不区分场景), + * talk(闲聊对话-如度秘聊天等), + * task(任务型对话-如导航对话等), + * customer_service(客服对话-如电信/银行客服等) + */ + Json::Value emotion_v1( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_emotion_v1, null, data.toStyledString(), null); + + return result; + } + + /** + * 新闻摘要 + * @param content 字符串仅支持GBK编码,长度需小于3000字符数(即6000字节) + * @param max_summary_len 此数值将作为摘要结果的最大长度。 + * options 可选参数: + * - title: 字符串仅支持GBK编码,长度需小于200字符数 + */ + Json::Value news_summary_v1( + std::string const & content, + int max_summary_len, + const std::map & options) + { + Json::Value data; + + data["content"] = content; + data["max_summary_len"] = max_summary_len; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_news_summary_v1, null, data.toStyledString(), null); + + return result; + } + + /** + * 地址识别 + * @param text 待识别的文本内容,不超过1000字节 + * options 可选参数: + * - confidence: 不设置时默认为-1 + */ + Json::Value address_v1( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_address_v1, null, data.toStyledString(), null); + + return result; + } + + /** + * 评论观点抽取「定制版」 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/ok6z52g8q + */ + Json::Value comment_tag_custom( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_comment_tag_custom, null, data.toStyledString(), null); + + return result; + } + /** + * 情感倾向分析「定制版」 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/zk6z52hds + */ + Json::Value sentiment_classify_custom( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_sentiment_classify_custom, null, data.toStyledString(), null); + + return result; + } + + /** + * 智能春联 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/Ok53wb6dh + */ + Json::Value couplets( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_couplets, null, data.toStyledString(), null); + + return result; + } + + /** + * 智能写诗 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/ak53wc3o3 + */ + Json::Value poem( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_poem, null, data.toStyledString(), null); + + return result; + } + + /** + * 实体抽取与情感倾向分析 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/Fk6z52g04#%E4%B8%BB%E6%8E%A5%E5%8F%A3%EF%BC%88%E5%AE%9E%E4%BD%93%E6%8A%BD%E5%8F%96%E4%B8%8E%E6%83%85%E6%84%9F%E5%80%BE%E5%90%91%E5%88%86%E6%9E%90%EF%BC%89 + */ + Json::Value entity_level_sentiment( + std::string const & title, + std::string const & content, + int const & type, + const std::map & options) + { + Json::Value data; + + data["title"] = title; + data["content"] = content; + data["type"] = type; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_entity_level_sentiment, null, data.toStyledString(), null); + + return result; + } + + /** + * 实体库新增接口 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/Fk6z52g04#%E5%AE%9E%E4%BD%93%E5%BA%93%E6%96%B0%E5%A2%9E%E6%8E%A5%E5%8F%A3 + */ + Json::Value entity_level_sentiment_add( + std::string const & repository, + std::vector entities, + const std::map & options) + { + Json::Value data; + + data["repository"] = repository; + Json::Value& en = data["entities"]; + int index{0}; + for (auto& e : entities) + en[index++] = std::move(e); + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_entity_level_sentiment_add, null, data.toStyledString(), null); + + return result; + } + + /** + * 实体名单删除接口 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/Fk6z52g04#%E5%AE%9E%E4%BD%93%E5%90%8D%E5%8D%95%E5%88%A0%E9%99%A4%E6%8E%A5%E5%8F%A3 + */ + Json::Value entity_level_sentiment_delete( + std::string const & repository, + std::vector entities, + const std::map & options) + { + Json::Value data; + + data["repository"] = repository; + Json::Value& en = data["entities"]; + int index{0}; + for (auto& e : entities) + en[index++] = std::move(e); + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_entity_level_sentiment_delete, null, data.toStyledString(), null); + + return result; + } + + /** + * 实体库删除接口 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/Fk6z52g04#%E5%AE%9E%E4%BD%93%E5%BA%93%E5%88%A0%E9%99%A4%E6%8E%A5%E5%8F%A3 + */ + Json::Value entity_level_sentiment_delete_repo( + std::string const & repository, + const std::map & options) + { + Json::Value data; + + data["repository"] = repository; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_entity_level_sentiment_delete_repo, null, data.toStyledString(), null); + + return result; + } + + /** + * 实体库查询接口 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/Fk6z52g04#%E5%AE%9E%E4%BD%93%E5%BA%93%E6%9F%A5%E8%AF%A2%E6%8E%A5%E5%8F%A3 + */ + Json::Value entity_level_sentiment_list( + const std::map & options) + { + Json::Value data = {}; + + Json::Value result = + this->request(_entity_level_sentiment_list, null, data.toStyledString(), null); + + return result; + } + + /** + * 实体名单查询接口 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/Fk6z52g04#%E5%AE%9E%E4%BD%93%E5%90%8D%E5%8D%95%E6%9F%A5%E8%AF%A2%E6%8E%A5%E5%8F%A3 + */ + Json::Value entity_level_sentiment_query( + std::string const & repository, + const std::map & options) + { + Json::Value data; + + data["repository"] = repository; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_entity_level_sentiment_query, null, data.toStyledString(), null); + + return result; + } + + /** + * 文章主题短语生成 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/9k53w3qob + */ + Json::Value topic_phrase( + std::string const & title, + std::string const & summary, + const std::map & options) + { + Json::Value data; + + data["title"] = title; + data["summary"] = summary; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_topic_phrase, null, data.toStyledString(), null); + + return result; + } + + /** + * 简历解析 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/Xkahvfeqa + */ + Json::Value recruitment_cvparser( + const std::map & resume) + { + Json::Value data; + + std::map::const_iterator it; + for(it = resume.begin(); it != resume.end(); it++) + { + data["resume"][it->first] = it->second; + } + + Json::Value result = + this->request(_recruitment_cvparser, null, data.toStyledString(), null); + + return result; + } + + /** + * 人岗匹配 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/Pkahwzux5 + */ + Json::Value recruitment_person_post( + const std::map & resume, + const std::map & job_description) + { + Json::Value data; + + std::map::const_iterator it_resume; + for(it_resume = resume.begin(); it_resume != resume.end(); it_resume++) + { + data["resume"][it_resume->first] = it_resume->second; + } + + std::map::const_iterator it_job; + for(it_job = job_description.begin(); it_job != job_description.end(); it_job++) + { + data["job_description"][it_job->first] = it_job->second; + } + + Json::Value result = + this->request(_recruitment_person_post, null, data.toStyledString(), null); + + return result; + } + + /** + * 简历画像 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/5kc1kmz3w + */ + Json::Value recruitment_personas( + const std::map & resume) + { + Json::Value data; + + std::map::const_iterator it; + for(it = resume.begin(); it != resume.end(); it++) + { + data["resume"][it->first] = it->second; + } + + Json::Value result = + this->request(_recruitment_personas, null, data.toStyledString(), null); + + return result; + } + + /** + * 文章标题生成 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/0kvc1u1eg + */ + Json::Value titlepredictor( + std::string const & doc, + const std::map & options) + { + Json::Value data; + + data["doc"] = doc; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_titlepredictor, null, data.toStyledString(), null); + + return result; + } + + /** + * 依存句法分析 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/nk6z52eu6 + */ + Json::Value depparser_v2( + std::string const & text, + const std::map & options) + { + + std::map data; + data["text"] = text; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_depparser_v2, null, data, null); + + return result; + } + + /** + * 祝福语生成 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/sl4cg75jk + */ + Json::Value bless_creation( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_bless_creation, null, data.toStyledString(), null); + + return result; + } + + /** + * 实体分析 + * 参数详情参考:https://ai.baidu.com/ai-doc/NLP/al631z295 + */ + Json::Value entity_analysis( + std::string const & text, + const std::map & options) + { + Json::Value data; + + data["text"] = text; + + std::map::const_iterator it; + for(it = options.begin(); it != options.end(); it++) + { + data[it->first] = it->second; + } + + Json::Value result = + this->request(_entity_analysis, null, data.toStyledString(), null); + + return result; + } + }; +} +#endif diff --git a/third/include/aip-cpp-sdk/ocr.h b/third/include/aip-cpp-sdk/ocr.h new file mode 100644 index 0000000..b794529 --- /dev/null +++ b/third/include/aip-cpp-sdk/ocr.h @@ -0,0 +1,3167 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_OCR_H__ +#define __AIP_OCR_H__ + +#include "base/base.h" + +namespace aip { + + class Ocr : public AipBase { + public: + + std::string _medical_detail = + "https://aip.baidubce.com/rest/2.0/ocr/v1/medical_detail"; + + std::string _weight_note = + "https://aip.baidubce.com/rest/2.0/ocr/v1/weight_note"; + + std::string _online_taxi_itinerary = + "https://aip.baidubce.com/rest/2.0/ocr/v1/online_taxi_itinerary"; + std::string _invoice = + "https://aip.baidubce.com/rest/2.0/ocr/v1/invoice"; + + std::string _passport = + "https://aip.baidubce.com/rest/2.0/ocr/v1/passport"; + + std::string _air_ticket = + "https://aip.baidubce.com/rest/2.0/ocr/v1/air_ticket"; + + std::string _household_register = + "https://aip.baidubce.com/rest/2.0/ocr/v1/household_register"; + + std::string _vehicle_certificate = + "https://aip.baidubce.com/rest/2.0/ocr/v1/vehicle_certificate"; + + std::string _vehicle_invoice = + "https://aip.baidubce.com/rest/2.0/ocr/v1/vehicle_invoice"; + + std::string _qrcode = + "https://aip.baidubce.com/rest/2.0/ocr/v1/qrcode"; + + std::string _doc_analysis_office = + "https://aip.baidubce.com/rest/2.0/ocr/v1/doc_analysis_office"; + + std::string _handwriting = + "https://aip.baidubce.com/rest/2.0/ocr/v1/handwriting"; + + std::string _doc_analysis = + "https://aip.baidubce.com/rest/2.0/ocr/v1/doc_analysis"; + + std::string _meter = + "https://aip.baidubce.com/rest/2.0/ocr/v1/meter"; + + std::string _webimage_loc = + "https://aip.baidubce.com/rest/2.0/ocr/v1/webimage_loc"; + + std::string _seal = + "https://aip.baidubce.com/rest/2.0/ocr/v1/seal"; + + std::string _general_basic = + "https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic"; + + std::string _accurate_basic = + "https://aip.baidubce.com/rest/2.0/ocr/v1/accurate_basic"; + + std::string _general = + "https://aip.baidubce.com/rest/2.0/ocr/v1/general"; + + std::string _accurate = + "https://aip.baidubce.com/rest/2.0/ocr/v1/accurate"; + + std::string _general_enhanced = + "https://aip.baidubce.com/rest/2.0/ocr/v1/general_enhanced"; + + std::string _webimage = + "https://aip.baidubce.com/rest/2.0/ocr/v1/webimage"; + + std::string _idcard = + "https://aip.baidubce.com/rest/2.0/ocr/v1/idcard"; + + std::string _bankcard = + "https://aip.baidubce.com/rest/2.0/ocr/v1/bankcard"; + + std::string _driving_license = + "https://aip.baidubce.com/rest/2.0/ocr/v1/driving_license"; + + std::string _vehicle_license = + "https://aip.baidubce.com/rest/2.0/ocr/v1/vehicle_license"; + + std::string _license_plate = + "https://aip.baidubce.com/rest/2.0/ocr/v1/license_plate"; + + std::string _business_license = + "https://aip.baidubce.com/rest/2.0/ocr/v1/business_license"; + + std::string _receipt = + "https://aip.baidubce.com/rest/2.0/ocr/v1/receipt"; + + std::string _table_recognize = + "https://aip.baidubce.com/rest/2.0/solution/v1/form_ocr/request"; + + std::string _table_result_get = + "https://aip.baidubce.com/rest/2.0/solution/v1/form_ocr/get_request_result"; + + std::string _vat_invoice = + "https://aip.baidubce.com/rest/2.0/ocr/v1/vat_invoice"; + + std::string _taxi_receipt = + "https://aip.baidubce.com/rest/2.0/ocr/v1/taxi_receipt"; + + std::string _vin_code = + "https://aip.baidubce.com/rest/2.0/ocr/v1/vin_code"; + + std::string _numbers = + "https://aip.baidubce.com/rest/2.0/ocr/v1/numbers"; + + std::string _train_ticket = + "https://aip.baidubce.com/rest/2.0/ocr/v1/train_ticket"; + + std::string _lottery_v1 = + "https://aip.baidubce.com/rest/2.0/ocr/v1/lottery"; + + std::string _insurance_documents_v1 = + "https://aip.baidubce.com/rest/2.0/ocr/v1/insurance_documents"; + + std::string _taiwan_exitentrypermit_v1 = + "https://aip.baidubce.com/rest/2.0/ocr/v1/taiwan_exitentrypermit"; + + std::string _HK_Macau_exitentrypermit_v1 = + "https://aip.baidubce.com/rest/2.0/ocr/v1/HK_Macau_exitentrypermit"; + + std::string _birth_certificate_v1 = + "https://aip.baidubce.com/rest/2.0/ocr/v1/birth_certificate"; + + std::string _business_card_v1 = + "https://aip.baidubce.com/rest/2.0/ocr/v1/business_card"; + + std::string _quota_invoice_v1 = + "https://aip.baidubce.com/rest/2.0/ocr/v1/quota_invoice"; + + std::string _recognise_iocr_v1 = + "https://aip.baidubce.com/rest/2.0/solution/v1/iocr/recognise"; + + std::string _recognise_iocr_finance = + "https://aip.baidubce.com/rest/2.0/solution/v1/iocr/recognise/finance"; + + std::string _bus_ticket = + "https://aip.baidubce.com/rest/2.0/ocr/v1/bus_ticket"; + + std::string _toll_invoice = + "https://aip.baidubce.com/rest/2.0/ocr/v1/toll_invoice"; + + std::string _multi_card_classify = + "https://aip.baidubce.com/rest/2.0/ocr/v1/multi_card_classify"; + + std::string _intelligent_ocr = + "https://aip.baidubce.com/rest/2.0/ocr/v1/intelligent_ocr"; + + std::string _medical_record = + "https://aip.baidubce.com/rest/2.0/ocr/v1/medical_record"; + + std::string _medical_statement = + "https://aip.baidubce.com/rest/2.0/ocr/v1/medical_statement"; + + std::string _ferry_ticket = + "https://aip.baidubce.com/rest/2.0/ocr/v1/ferry_ticket"; + + std::string _used_vehicle_invoice = + "https://aip.baidubce.com/rest/2.0/ocr/v1/used_vehicle_invoice"; + + std::string _multi_idcard = + "https://aip.baidubce.com/rest/2.0/ocr/v1/multi_idcard"; + + std::string _travel_card = + "https://aip.baidubce.com/rest/2.0/ocr/v1/travel_card"; + + std::string _social_security_card = + "https://aip.baidubce.com/rest/2.0/ocr/v1/social_security_card"; + + std::string _medical_report_detection = + "https://aip.baidubce.com/rest/2.0/ocr/v1/medical_report_detection"; + + std::string _medical_recipts_classify = + "https://aip.baidubce.com/rest/2.0/ocr/v1/medical_recipts_classify"; + + std::string _waybill = + "https://aip.baidubce.com/rest/2.0/ocr/v1/waybill"; + + std::string _medical_summary = + "https://aip.baidubce.com/rest/2.0/ocr/v1/medical_summary"; + + std::string _shopping_receipt = + "https://aip.baidubce.com/rest/2.0/ocr/v1/shopping_receipt"; + + std::string _road_transport_certificate = + "https://aip.baidubce.com/rest/2.0/ocr/v1/road_transport_certificate"; + + std::string _table = + "https://aip.baidubce.com/rest/2.0/ocr/v1/table"; + + std::string _remove_handwriting = + "https://aip.baidubce.com/rest/2.0/ocr/v1/remove_handwriting"; + + std::string _doc_crop_enhance = + "https://aip.baidubce.com/rest/2.0/ocr/v1/doc_crop_enhance"; + + std::string _health_code = + "https://aip.baidubce.com/rest/2.0/ocr/v1/health_code"; + + std::string _covid_test = + "https://aip.baidubce.com/rest/2.0/ocr/v1/covid_test"; + + std::string _medical_prescription = + "https://aip.baidubce.com/rest/2.0/ocr/v1/medical_prescription"; + + std::string _medical_outpatient = + "https://aip.baidubce.com/rest/2.0/ocr/v1/medical_outpatient"; + + std::string _medical_summary_diagnosis = + "https://aip.baidubce.com/rest/2.0/ocr/v1/medical_summary_diagnosis"; + + std::string _health_report = + "https://aip.baidubce.com/rest/2.0/ocr/v1/health_report"; + + std::string _doc_convert_request_v1 = "https://aip.baidubce.com/rest/2.0/ocr/v1/doc_convert/request"; + + std::string _doc_convert_result_v1 = "https://aip.baidubce.com/rest/2.0/ocr/v1/doc_convert/get_request_result"; + + std::string _bank_receipt_new = "https://aip.baidubce.com/rest/2.0/ocr/v1/bank_receipt_new"; + + std::string _marriage_certificate = "https://aip.baidubce.com/rest/2.0/ocr/v1/marriage_certificate"; + + std::string _hk_macau_taiwan_exitentrypermit = + "https://aip.baidubce.com/rest/2.0/ocr/v1/hk_macau_taiwan_exitentrypermit"; + + Ocr(const std::string &app_id, const std::string &ak, const std::string &sk) : AipBase(app_id, ak, sk) { + } + + /** + * general_basic + * 用户向服务请求识别某张图中的所有文字 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * language_type 识别语言类型,默认为CHN_ENG。可选值包括:
- CHN_ENG:中英文混合;
- ENG:英文;
- POR:葡萄牙语;
- FRE:法语;
- GER:德语;
- ITA:意大利语;
- SPA:西班牙语;
- RUS:俄语;
- JAP:日语;
- KOR:韩语; + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * detect_language 是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + * probability 是否返回识别结果中每一行的置信度 + */ + Json::Value general_basic( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_general_basic, null, data, null); + + return result; + } + + /** + * general_basic_url + * 用户向服务请求识别某张图中的所有文字 + * @param url 图片完整URL,URL长度不超过1024字节,URL对应的图片base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式,当image字段存在时url字段失效 + + * options 可选参数: + * language_type 识别语言类型,默认为CHN_ENG。可选值包括:
- CHN_ENG:中英文混合;
- ENG:英文;
- POR:葡萄牙语;
- FRE:法语;
- GER:德语;
- ITA:意大利语;
- SPA:西班牙语;
- RUS:俄语;
- JAP:日语;
- KOR:韩语; + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * detect_language 是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + * probability 是否返回识别结果中每一行的置信度 + */ + Json::Value general_basic_url( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_general_basic, null, data, null); + + return result; + } + + /** + * 通用文字识别(标准版) + * https://ai.baidu.com/ai-doc/OCR/zk3h7xz52 + * + * @param pdf + * @param options + * @return + */ + Json::Value general_basic_pdf( + std::string const &pdf, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf.c_str(), (int) pdf.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_general_basic, null, data, null); + + return result; + } + + /** + * accurate_basic + * 用户向服务请求识别某张图中的所有文字,相对于通用文字识别该产品精度更高,但是没有免费额度,如果您需要使用该产品,您可以在产品页面点击合作咨询或加入文字识别的官网QQ群:631977213向管理员申请试用。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * probability 是否返回识别结果中每一行的置信度 + */ + Json::Value accurate_basic( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_accurate_basic, null, data, null); + + return result; + } + + /** + * 通用文字识别(高精度版) + * https://ai.baidu.com/ai-doc/OCR/1k3h7y3db + * + * @param url + * @param options + * @return + */ + Json::Value accurate_basic_url( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_accurate_basic, null, data, null); + + return result; + } + + /** + * 通用文字识别(高精度版) + * https://ai.baidu.com/ai-doc/OCR/1k3h7y3db + * + * @param pdf + * @param options + * @return + */ + Json::Value accurate_basic_pdf( + std::string const &pdf, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf.c_str(), (int) pdf.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_accurate_basic, null, data, null); + + return result; + } + + /** + * general + * 用户向服务请求识别某张图中的所有文字,并返回文字在图中的位置信息。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * recognize_granularity 是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置 + * language_type 识别语言类型,默认为CHN_ENG。可选值包括:
- CHN_ENG:中英文混合;
- ENG:英文;
- POR:葡萄牙语;
- FRE:法语;
- GER:德语;
- ITA:意大利语;
- SPA:西班牙语;
- RUS:俄语;
- JAP:日语;
- KOR:韩语; + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * detect_language 是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + * vertexes_location 是否返回文字外接多边形顶点位置,不支持单字位置。默认为false + * probability 是否返回识别结果中每一行的置信度 + */ + Json::Value general( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_general, null, data, null); + + return result; + } + + /** + * general_url + * 用户向服务请求识别某张图中的所有文字,并返回文字在图中的位置信息。 + * @param url 图片完整URL,URL长度不超过1024字节,URL对应的图片base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式,当image字段存在时url字段失效 + + * options 可选参数: + * recognize_granularity 是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置 + * language_type 识别语言类型,默认为CHN_ENG。可选值包括:
- CHN_ENG:中英文混合;
- ENG:英文;
- POR:葡萄牙语;
- FRE:法语;
- GER:德语;
- ITA:意大利语;
- SPA:西班牙语;
- RUS:俄语;
- JAP:日语;
- KOR:韩语; + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * detect_language 是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + * vertexes_location 是否返回文字外接多边形顶点位置,不支持单字位置。默认为false + * probability 是否返回识别结果中每一行的置信度 + */ + Json::Value general_url( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_general, null, data, null); + + return result; + } + + /** + * 通用文字识别(标准含位置版) + * https://ai.baidu.com/ai-doc/OCR/vk3h7y58v + * + * @param pdf + * @param options + * @return + */ + Json::Value general_pdf( + std::string const &pdf, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf.c_str(), (int) pdf.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_general, null, data, null); + + return result; + } + + /** + * accurate + * 用户向服务请求识别某张图中的所有文字,相对于通用文字识别(含位置信息版)该产品精度更高,但是没有免费额度,如果您需要使用该产品,您可以在产品页面点击合作咨询或加入文字识别的官网QQ群:631977213向管理员申请试用。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * recognize_granularity 是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置 + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * vertexes_location 是否返回文字外接多边形顶点位置,不支持单字位置。默认为false + * probability 是否返回识别结果中每一行的置信度 + */ + Json::Value accurate( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_accurate, null, data, null); + + return result; + } + + /** + * 通用文字识别(高精度含位置版) + * https://ai.baidu.com/ai-doc/OCR/tk3h7y2aq + * + * @param url + * @param options + * @return + */ + Json::Value accurate_url( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_accurate, null, data, null); + + return result; + } + + /** + * 通用文字识别(高精度含位置版) + * https://ai.baidu.com/ai-doc/OCR/tk3h7y2aq + * + * @param pdf + * @param options + * @return + */ + Json::Value accurate_pdf( + std::string const &pdf, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf.c_str(), (int) pdf.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_accurate, null, data, null); + + return result; + } + + /** + * general_enhanced + * 某些场景中,图片中的中文不光有常用字,还包含了生僻字,这时用户需要对该图进行文字识别,应使用通用文字识别(含生僻字版)。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * language_type 识别语言类型,默认为CHN_ENG。可选值包括:
- CHN_ENG:中英文混合;
- ENG:英文;
- POR:葡萄牙语;
- FRE:法语;
- GER:德语;
- ITA:意大利语;
- SPA:西班牙语;
- RUS:俄语;
- JAP:日语;
- KOR:韩语; + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * detect_language 是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + * probability 是否返回识别结果中每一行的置信度 + */ + Json::Value general_enhanced( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_general_enhanced, null, data, null); + + return result; + } + + /** + * general_enhanced_url + * 某些场景中,图片中的中文不光有常用字,还包含了生僻字,这时用户需要对该图进行文字识别,应使用通用文字识别(含生僻字版)。 + * @param url 图片完整URL,URL长度不超过1024字节,URL对应的图片base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式,当image字段存在时url字段失效 + + * options 可选参数: + * language_type 识别语言类型,默认为CHN_ENG。可选值包括:
- CHN_ENG:中英文混合;
- ENG:英文;
- POR:葡萄牙语;
- FRE:法语;
- GER:德语;
- ITA:意大利语;
- SPA:西班牙语;
- RUS:俄语;
- JAP:日语;
- KOR:韩语; + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * detect_language 是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + * probability 是否返回识别结果中每一行的置信度 + */ + Json::Value general_enhanced_url( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_general_enhanced, null, data, null); + + return result; + } + + /** + * webimage + * 用户向服务请求识别一些网络上背景复杂,特殊字体的文字。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * detect_language 是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + */ + Json::Value webimage( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_webimage, null, data, null); + + return result; + } + + /** + * webimage_url + * 用户向服务请求识别一些网络上背景复杂,特殊字体的文字。 + * @param url 图片完整URL,URL长度不超过1024字节,URL对应的图片base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式,当image字段存在时url字段失效 + + * options 可选参数: + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * detect_language 是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + */ + Json::Value webimage_url( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_webimage, null, data, null); + + return result; + } + + /** + * idcard + * 用户向服务请求识别身份证,身份证识别包括正面和背面。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * @param id_card_side front:身份证正面;back:身份证背面 + * options 可选参数: + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * detect_risk 是否开启身份证风险类型(身份证复印件、临时身份证、身份证翻拍、修改过的身份证)功能,默认不开启,即:false。可选值:true-开启;false-不开启 + */ + Json::Value idcard( + std::string const &image, + std::string const &id_card_side, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["id_card_side"] = id_card_side; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_idcard, null, data, null); + + return result; + } + + /** + * 身份证识别 + * https://ai.baidu.com/ai-doc/OCR/rk3h7xzck + * + * @param url + * @param id_card_side + * @param options + * @return + */ + Json::Value idcard_url( + std::string const &url, + std::string const &id_card_side, + const std::map &options) { + std::map data; + + data["url"] = url; + data["id_card_side"] = id_card_side; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_idcard, null, data, null); + + return result; + } + + /** + * bankcard + * 识别银行卡并返回卡号和发卡行。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value bankcard( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_bankcard, null, data, null); + + return result; + } + + /** + * driving_license + * 对机动车驾驶证所有关键字段进行识别 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + */ + Json::Value driving_license( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_driving_license, null, data, null); + + return result; + } + + /** + * 驾驶证识别 + * https://ai.baidu.com/ai-doc/OCR/Vk3h7xzz7 + * + * @param url + * @param options + * @return + */ + Json::Value driving_license_url( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_driving_license, null, data, null); + + return result; + } + + /** + * vehicle_license + * 对机动车行驶证正本所有关键字段进行识别 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:
- true:检测朝向;
- false:不检测朝向。 + * accuracy normal 使用快速服务,1200ms左右时延;缺省或其它值使用高精度服务,1600ms左右时延 + */ + Json::Value vehicle_license( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_vehicle_license, null, data, null); + + return result; + } + + /** + * 行驶证识别 + * https://ai.baidu.com/ai-doc/OCR/yk3h7y3ks + * + * @param url + * @param options + * @return + */ + Json::Value vehicle_license_url( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_vehicle_license, null, data, null); + + return result; + } + + /** + * license_plate + * 识别机动车车牌,并返回签发地和号牌。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value license_plate( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_license_plate, null, data, null); + + return result; + } + + /** + * business_license + * 识别营业执照,并返回关键字段的值,包括单位名称、法人、地址、有效期、证件编号、社会信用代码等。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value business_license( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_business_license, null, data, null); + + return result; + } + + /** + * receipt + * 用户向服务请求识别医疗票据、发票、的士票、保险保单等票据类图片中的所有文字,并返回文字在图中的位置信息。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + * recognize_granularity 是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置 + * probability 是否返回识别结果中每一行的置信度 + * accuracy normal 使用快速服务,1200ms左右时延;缺省或其它值使用高精度服务,1600ms左右时延 + */ + Json::Value receipt( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_receipt, null, data, null); + + return result; + } + + /** + * table_recognize + * 自动识别表格线及表格内容,结构化输出表头、表尾及每个单元格的文字内容。表格文字识别接口为异步接口,分为两个API:提交请求接口、获取结果接口。 + * @param image 图像文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value table_recognize( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_table_recognize, null, data, null); + + return result; + } + + /** + * table_result_get + * 获取表格文字识别结果 + * @param request_id 发送表格文字识别请求时返回的request id + * options 可选参数: + * result_type 期望获取结果的类型,取值为“excel”时返回xls文件的地址,取值为“json”时返回json格式的字符串,默认为”excel” + */ + Json::Value table_result_get( + std::string const &request_id, + const std::map &options) { + std::map data; + + data["request_id"] = request_id; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_table_result_get, null, data, null); + + return result; + } + + /** + * + * 增值税发票识别 + * @param image 图像二进制内容 + * options 可选参数: + */ + Json::Value vatInvoice( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_vat_invoice, null, data, null); + + return result; + } + + /** + * + * 增值税发票识别 + * @param image 发票图像URL地址 + * options 可选参数: + */ + Json::Value vatInvoiceUrl( + std::string const &image, + const std::map &options) { + std::map data; + + data["url"] = image; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_vat_invoice, null, data, null); + + return result; + } + + /** + * + * 增值税发票识别 + * @param image 发票pdf文件二进制数据 + * options 可选参数: + */ + Json::Value vatInvoicePdf( + std::string const &image, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_vat_invoice, null, data, null); + + return result; + } + + /** + * + * 出租车发票识别 + * @param image 图像二进制内容 + * options 可选参数: + */ + Json::Value taxiReceipt( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_taxi_receipt, null, data, null); + + return result; + } + + /** + * + * 出租车票识别 + * @param image 发票图像URL地址 + * options 可选参数: + */ + Json::Value taxiReceiptUrl( + std::string const &image, + const std::map &options) { + std::map data; + + data["url"] = image; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_taxi_receipt, null, data, null); + + return result; + } + + /** + * + * vin码识别 + * @param image 图像二进制内容 + * options 可选参数: + */ + Json::Value vinCode( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_vin_code, null, data, null); + + return result; + } + + /** + * + * vin 码识别 + * @param image 发票图像URL地址 + * options 可选参数: + */ + Json::Value vinCodeUrl( + std::string const &image, + const std::map &options) { + std::map data; + + data["url"] = image; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_vin_code, null, data, null); + + return result; + } + + /** + * + * 火车票票识别 + * @param image 图像二进制内容 + * options 可选参数: + */ + Json::Value trainTicket( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_train_ticket, null, data, null); + + return result; + } + + /** + * + * 火车票票识别 + * @param image 发票图像URL地址 + * options 可选参数: + */ + Json::Value trainTicketUrl( + std::string const &image, + const std::map &options) { + std::map data; + + data["url"] = image; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_train_ticket, null, data, null); + + return result; + } + + /** + * + * 数字识别 + * @param image 图像二进制内容 + * options 可选参数: + */ + Json::Value numbers( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_numbers, null, data, null); + + return result; + } + + + /** + * 印章识别 + * 检测并识别合同文件或常用票据中的印章,输出文字内容、印章位置信息以及相关置信度,已支持圆形章、椭圆形章、方形章等常见印章检测与识别 + * @param image 二进制图像数据 + * options 可选参数: + */ + Json::Value seal( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_seal, null, data, null); + + return result; + } + + /** + * 网络图片文字识别(含位置版) + * 支持识别艺术字体或背景复杂的文字内容,除文字信息外,还可返回每行文字的位置信息、行置信度,以及单字符内容和位置等。 + * @param image 二进制图像数据 + * options 可选参数: + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度 + * probability 是否返回每行识别结果的置信度。默认为false + * poly_location 是否返回文字所在区域的外接四边形的4个点坐标信息。默认为false + * recognize_granularity 是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置 + */ + Json::Value webimageloc( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_webimage_loc, null, data, null); + + return result; + } + + /** + * 网络图片文字识别(含位置版) + * 支持识别艺术字体或背景复杂的文字内容,除文字信息外,还可返回每行文字的位置信息、行置信度,以及单字符内容和位置等。 + * @param url 图片完整URL + * options 可选参数: + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度 + * probability 是否返回每行识别结果的置信度。默认为false + * poly_location 是否返回文字所在区域的外接四边形的4个点坐标信息。默认为false + * recognize_granularity 是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置 + */ + Json::Value webimagelocurl( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_webimage_loc, null, data, null); + + return result; + } + + /** + * 仪器仪表盘读数识别 + * 适用于不同品牌、不同型号的仪器仪表盘读数识别,广泛适用于各类血糖仪、血压仪、燃气表、电表等,可识别表盘上的数字、英文、符号,支持液晶屏、字轮表等表型。 + * @param image 二进制图像数据 + * options 可选参数: + * probability 是否返回每行识别结果的置信度。默认为false + * poly_location 位置信息返回形式,默认:false + */ + Json::Value meter( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_meter, null, data, null); + + return result; + } + + /** + * 试卷分析与识别 + * 可对文档版面进行分析,输出图、表、标题、文本的位置,并输出分版块内容的OCR识别结果,支持中、英两种语言,手写、印刷体混排多种场景 + * @param url 图片url + * options 可选参数: + * language_type 识别语言类型,默认为CHN_ENG 可选值包括:CHN_ENG:中英文 ENG:英文 + * result_type 返回识别结果是按单行结果返回,还是按单字结果返回,默认为big。 + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度 + * line_probability 是否返回每行识别结果的置信度。默认为false + * words_type 文字类型。默认:印刷文字识别 + * layout_analysis 是否分析文档版面:包括图、表、标题、段落的分析输出 + */ + Json::Value docanalysisurl( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_doc_analysis, null, data, null); + + return result; + } + + /** + * 手写文字识别 + * 支持对图片中的手写中文、手写数字进行检测和识别,针对不规则的手写字体进行专项优化,识别准确率可达90%以上 + * @param image 二进制图像数据 + * options 可选参数: + * recognize_granularity 是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置 + * probability 是否返回识别结果中每一行的置信度,默认为false,不返回置信度 + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括 + */ + Json::Value handwriting( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_handwriting, null, data, null); + + return result; + } + + /** + * 办公文档识别 + * 可对办公类文档版面进行分析,输出图、表、标题、文本的位置,并输出分版块内容的OCR识别结果,支持中、英两种语言,手写、印刷体混排多种场景。 + * @param image 二进制图像数据 + * options 可选参数: + * language_type 识别语言类型,默认为CHN_ENG 可选值包括:CHN_ENG:中英文 ENG:英文 + * result_type 返回识别结果是按单行结果返回,还是按单字结果返回,默认为big。 + * detect_direction 是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度 + * line_probability 是否返回每行识别结果的置信度。默认为false + * words_type 文字类型。默认:印刷文字识别 + * layout_analysis 是否分析文档版面:包括图、表、标题、段落的分析输出 + * erase_seal 是否先擦除水印、印章后再识别文档 + */ + Json::Value docanalysisoffice( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_doc_analysis_office, null, data, null); + + return result; + } + + /** + * 二维码识别 + * 对图片中的二维码、条形码进行检测和识别,返回存储的文字信息 + * @param image 二进制图像数据 + * options 可选参数: + * recognize_granularity 是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置&probability + */ + Json::Value qrcode( + std::string const &image, + const std::map &options) { + std::map data; + + data["image"] = base64_encode(image.c_str(), (int) image.size()); + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_qrcode, null, data, null); + + return result; + } + + /** + * 仪器仪表盘读数识别 + * 适用于不同品牌、不同型号的仪器仪表盘读数识别,广泛适用于各类血糖仪、血压仪、燃气表、电表等,可识别表盘上的数字、英文、符号,支持液晶屏、字轮表等表型。 + * @param url 图像url地址 + * options 可选参数: + * probability 是否返回每行识别结果的置信度。默认为false + * poly_location 位置信息返回形式,默认:false + */ + Json::Value meterurl( + std::string const &url, + const std::map &options) { + std::map data; + + data["url"] = url; + + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_meter, null, data, null); + + return result; + } + + /** + * 二维码识别 + * 对图片中的二维码、条形码进行检测和识别,返回存储的文字信息 + * @param url 图片完整URL + * options 可选参数: + + */ + Json::Value qrcodeUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_qrcode, null, data, null); + return result; + } + + /** + * 试卷分析与识别 + * 支持对车辆合格证的23个关键字段进行结构化识别 + * @param image 二进制图像数据 + * options 可选参数: + * multi_detect 控制是否开启多航班信息识别功能,默认值:false + */ + Json::Value docAnalysis( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_doc_analysis, null, data, null); + return result; + } + + /** + * 试卷分析与识别 + * 支持对车辆合格证的23个关键字段进行结构化识别 + * @param url 图片完整URL + * options 可选参数: + * multi_detect 控制是否开启多航班信息识别功能,默认值:false + */ + Json::Value docAnalysisUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_doc_analysis, null, data, null); + return result; + } + + /** + * 机动车销售发票 + * 支持对机动车销售发票的26个关键字段进行结构化识别, + * @param image 二进制图像数据 + * options 可选参数: + + */ + Json::Value vehicleInvoice( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_invoice, null, data, null); + return result; + } + + /** + * 机动车销售发票 + * 支持对机动车销售发票的26个关键字段进行结构化识别, + * @param url 图片完整URL + * options 可选参数: + + */ + Json::Value vehicleInvoiceUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_invoice, null, data, null); + return result; + } + + /** + * 车辆合格证 + * 支持对车辆合格证的23个关键字段进行结构化识别,包括合格证编号、发证日期、车辆制造企业名、车辆品牌、车辆名称、车辆型号、车架号、车身颜色、 + 发动机型号、发动机号、燃料种类、排量、功率、排放标准、轮胎数、轴距、轴数、转向形式、总质量、整备质量、驾驶室准乘人数、最高设计车速、车辆制造日期 + * @param image 二进制图像数据 + * options 可选参数: + * language_type 识别语言类型,默认为CHN_ENG * result_type 返回识别结果是按单行结果返回,还是按单字结果返回,默认为big + * detect_direction 是否检测图像朝向,默认不检测,即:false * line_probability 是否返回每行识别结果的置信度。默认为false + * words_type 文字类型。 + */ + Json::Value vehicleCertificate( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_certificate, null, data, null); + return result; + } + + /** + * 车辆合格证 + * 支持对车辆合格证的23个关键字段进行结构化识别,包括合格证编号、发证日期、车辆制造企业名、车辆品牌、车辆名称、车辆型号、车架号、车身颜色、发动机型号 + 、发动机号、燃料种类、排量、功率、排放标准、轮胎数、轴距、轴数、转向形式、总质量、整备质量、驾驶室准乘人数、最高设计车速、车辆制造日期 + * @param url 图片完整URL + * options 可选参数: + * language_type 识别语言类型,默认为CHN_ENG * result_type 返回识别结果是按单行结果返回,还是按单字结果返回,默认为big + * detect_direction 是否检测图像朝向,默认不检测,即:false * line_probability 是否返回每行识别结果的置信度。默认为false + * words_type 文字类型。 + */ + Json::Value vehicleCertificateUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_vehicle_certificate, null, data, null); + return result; + } + + /** + * 户口本识别 + * 支持对户口本内常住人口登记卡的全部 22 个字段进行结构化识别, + * @param image 二进制图像数据 + * options 可选参数: + + */ + Json::Value householdRegister( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_household_register, null, data, null); + return result; + } + + /** + * 户口本识别 + * 支持对户口本内常住人口登记卡的全部 22 个字段进行结构化识别, + * @param url 图片完整URL + * options 可选参数: + + */ + Json::Value householdRegisterUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_household_register, null, data, null); + return result; + } + + /** + * 手写文字识别 + * 支持对图片中的手写中文、手写数字进行检测和识别, + * @param url 图片完整URL + * options 可选参数: + * recognize_granularity 是否定位单字符位置, * probability 是否返回识别结果中每一行的置信度,默认为false,不返回置信度 + * detect_direction 是否检测图像朝向,默认不检测,即:false + */ + Json::Value handwritingUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_handwriting, null, data, null); + return result; + } + + /** + * 飞机行程单识别 + * 支持对飞机行程单的24个字段进行结构化识别,包括电子客票号、印刷序号、姓名、始发站、目的站、航班号、日期、时间、票价、身份证号、承运人、民航发展基金、 + 保险费、燃油附加费、其他税费、合计金额、填开日期、订票渠道、客票级别、座位等级、销售单位号、签注、免费行李、验证码。 同时,支持单张行程单上的多航班信息识别。 + * @param image 二进制图像数据 + * options 可选参数: + * multi_detect 控制是否开启多航班信息识别功能,默认值:false + */ + Json::Value airTicket( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_air_ticket, null, data, null); + return result; + } + + /** + * 飞机行程单识别 + * 支持对飞机行程单的24个字段进行结构化识别,包括电子客票号、印刷序号、姓名、始发站、目的站、航班号、日期、时间、票价、身份证号、承运人、 + 民航发展基金、保险费、燃油附加费、其他税费、合计金额、填开日期、订票渠道、客票级别、座位等级、销售单位号、签注、免费行李、验证码。 同时, + 支持单张行程单上的多航班信息识别。 + * @param url 图片完整URL + * options 可选参数: + * multi_detect 控制是否开启多航班信息识别功能,默认值:false + */ + Json::Value airTicketUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_air_ticket, null, data, null); + return result; + } + + /** + * 通用机打发票 + * 支持对图片中的手写中文、手写数字进行检测和识别, + * @param image 二进制图像数据 + * options 可选参数: + * location 是否输出位置信息,true:输出位置信息,false:不输出位置信息,默认false + */ + Json::Value invoice( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_invoice, null, data, null); + return result; + } + + /** + * 通用机打发票 + * 支持对图片中的手写中文、手写数字进行检测和识别, + * @param url 图片完整URL + * options 可选参数: + * location 是否输出位置信息,true:输出位置信息,false:不输出位置信息,默认false + */ + Json::Value invoiceUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_invoice, null, data, null); + return result; + } + + /** + * 护照识别 + * 支持对图片中的手写中文、手写数字进行检测和识别, + * @param image 二进制图像数据 + * options 可选参数: + + */ + Json::Value passport( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_passport, null, data, null); + return result; + } + + /** + * 护照识别 + * 支持对图片中的手写中文、手写数字进行检测和识别, + * @param url 图片完整URL + * options 可选参数: + + */ + Json::Value passportUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_passport, null, data, null); + return result; + } + + /** + * 网约车行程单识别 + * 对各大主要服务商的网约车行程单进行结构化识别,包括滴滴打车、花小猪打车、高德地图、曹操出行、阳光出行,支持识别服务商、 + 行程开始时间、行程结束时间、车型、总金额等16 个关键字段。 + + * @param image 二进制图像数据 + * options 可选参数: + * pdf_file_num 需要识别的PDF文件的对应页码,当 pdf_file 参数有效时,识别传入页码的对应页面内容,若不传入,则默认识别第 1 页 + */ + Json::Value onlineTaxiItinerary( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_online_taxi_itinerary, null, data, null); + return result; + } + + /** + * 网约车行程单识别 + * 对各大主要服务商的网约车行程单进行结构化识别,包括滴滴打车、花小猪打车、高德地图、曹操出行、阳光出行,支持识别服务商、 + 行程开始时间、行程结束时间、车型、总金额等16 个关键字段。 + + * @param url 图片完整URL路径 + * options 可选参数: + * pdf_file_num 需要识别的PDF文件的对应页码,当 pdf_file 参数有效时,识别传入页码的对应页面内容,若不传入,则默认识别第 1 页 + */ + Json::Value onlineTaxiItineraryUrl( + std::string url) { + std::map data; + data["url"] = url; + Json::Value result = + this->request(_online_taxi_itinerary, null, data, null); + return result; + } + + /** + * 网约车行程单识别 + * 对各大主要服务商的网约车行程单进行结构化识别,包括滴滴打车、花小猪打车、高德地图、曹操出行、阳光出行,支持识别服务商、 + 行程开始时间、行程结束时间、车型、总金额等16 个关键字段。 + + * @param pdf_file pdf文件二进制数据 + * options 可选参数: + * pdf_file_num 需要识别的PDF文件的对应页码,当 pdf_file 参数有效时,识别传入页码的对应页面内容,若不传入,则默认识别第 1 页 + */ + Json::Value onlineTaxiItineraryPdf( + std::string pdf_file, + std::map options) { + std::map data; + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_online_taxi_itinerary, null, data, null); + return result; + } + + /** + * 磅单识别 + * 结构化识别磅单的车牌号、打印时间、毛重、皮重、净重、发货单位、收货单位、单号8个关键字段,现阶段仅支持识别印刷体磅单 + * @param image 二进制图像数据 + * options 可选参数: + * pdf_file_num 需要识别的PDF文件的对应页码,当 pdf_file 参数有效时,识别传入页码的对应页面内容,若不传入,则默认识别第 1 页 * probability 是否返回字段识别结果的置信度,默认为 false,可缺省 + - false:不返回字段识别结果的置信度 + - true:返回字段识别结果的置信度,包括字段识别结果中各字符置信度的平均值(average)和最小值(min) + + */ + Json::Value weightNote( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_weight_note, null, data, null); + return result; + } + + /** + * 磅单识别 + * 结构化识别磅单的车牌号、打印时间、毛重、皮重、净重、发货单位、收货单位、单号8个关键字段,现阶段仅支持识别印刷体磅单 + * @param url 图片完整URL路径 + * options 可选参数: + * pdf_file_num 需要识别的PDF文件的对应页码,当 pdf_file 参数有效时,识别传入页码的对应页面内容,若不传入,则默认识别第 1 页 * probability 是否返回字段识别结果的置信度,默认为 false,可缺省 + - false:不返回字段识别结果的置信度 + - true:返回字段识别结果的置信度,包括字段识别结果中各字符置信度的平均值(average)和最小值(min) + + */ + Json::Value weightNoteUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_weight_note, null, data, null); + return result; + } + + /** + * 磅单识别 + * 结构化识别磅单的车牌号、打印时间、毛重、皮重、净重、发货单位、收货单位、单号8个关键字段,现阶段仅支持识别印刷体磅单 + * @param pdf_file 图片完整URL路径 + * options 可选参数: + * pdf_file_num 需要识别的PDF文件的对应页码,当 pdf_file 参数有效时,识别传入页码的对应页面内容,若不传入,则默认识别第 1 页 * probability 是否返回字段识别结果的置信度,默认为 false,可缺省 + - false:不返回字段识别结果的置信度 + - true:返回字段识别结果的置信度,包括字段识别结果中各字符置信度的平均值(average)和最小值(min) + + */ + Json::Value weightNotePdf( + std::string pdf_file, + std::map options) { + std::map data; + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_weight_note, null, data, null); + return result; + } + + /** + * 医疗费用明细识别 + * 支持识别全国医疗费用明细的姓名、日期、病人ID、总金额等关键字段,支持识别费用明细项目清单,包含项目类型、项目名称、 + 单价、数量、规格、金额,其中北京地区识别效果最佳。 + + * @param image 二进制图像数据 + * options 可选参数: + * location 是否返回字段的位置信息,默认为 false,可缺省 + - false:不返回字段位置信息 + - true:返回字段的位置信息,包括上边距(top)、左边距(left)、宽度(width)、高度(height) + * probability 是否返回字段识别结果的置信度,默认为 false,可缺省 + - false:不返回字段识别结果的置信度 + - true:返回字段识别结果的置信度,包括字段识别结果中各字符置信度的平均值(average)和最小值(min) + + */ + Json::Value medicalDetail( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_detail, null, data, null); + return result; + } + + /** + * 医疗费用明细识别 + * 支持识别全国医疗费用明细的姓名、日期、病人ID、总金额等关键字段,支持识别费用明细项目清单,包含项目类型、项目名称、 + 单价、数量、规格、金额,其中北京地区识别效果最佳。 + + * @param url 图片完整URL路径 + * options 可选参数: + * location 是否返回字段的位置信息,默认为 false,可缺省 + - false:不返回字段位置信息 + - true:返回字段的位置信息,包括上边距(top)、左边距(left)、宽度(width)、高度(height) + * probability 是否返回字段识别结果的置信度,默认为 false,可缺省 + - false:不返回字段识别结果的置信度 + - true:返回字段识别结果的置信度,包括字段识别结果中各字符置信度的平均值(average)和最小值(min) + + */ + Json::Value medicalDetailUrl( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_detail, null, data, null); + return result; + } + + + /** + * 彩票识别 + * @param image 二进制图像数据 + * @param options 可选参数: + * - recognize_granularity: 是否定位单字符位置, + * - big:不定位单字符位置,默认值 + * - small:定位单字符位置 + */ + Json::Value lottery_v1( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_lottery_v1, null, data, null); + return result; + } + + /** + * 彩票识别 + * @param url 图像链接 + * @param options 可选参数: + * - recognize_granularity: 是否定位单字符位置, + * - big:不定位单字符位置,默认值 + * - small:定位单字符位置 + */ + Json::Value lottery_v1_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_lottery_v1, null, data, null); + return result; + } + + /** + * 保险单识别 + * @param image 二进制图像数据 + * @param options 可选参数: + */ + Json::Value insurance_documents_v1( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_insurance_documents_v1, null, data, null); + return result; + } + + /** + * 保险单识别 + * @param url 图像链接 + * @param options 可选参数: + */ + Json::Value insurance_documents_v1_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_insurance_documents_v1, null, data, null); + return result; + } + + /** + * 台湾通行证识别 + * @param image 二进制图像数据 + * @param options 可选参数: + */ + Json::Value taiwan_exitentrypermit_v1( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_taiwan_exitentrypermit_v1, null, data, null); + return result; + } + + /** + * 台湾通行证识别 + * @param url 图像链接 + * @param options 可选参数: + */ + Json::Value taiwan_exitentrypermit_v1_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_taiwan_exitentrypermit_v1, null, data, null); + return result; + } + + /** + * 港澳通行证识别 + * @param image 二进制图像数据 + * @param options 可选参数: + */ + Json::Value hk_macau_exitentrypermit_v1( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_HK_Macau_exitentrypermit_v1, null, data, null); + return result; + } + + /** + * 港澳通行证识别 + * @param url 图像链接 + * @param options 可选参数: + */ + Json::Value hk_macau_exitentrypermit_v1_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_HK_Macau_exitentrypermit_v1, null, data, null); + return result; + } + + + /** + * 出生医学证明识别 + * @param image 二进制图像数据 + * @param options 可选参数: + */ + Json::Value birth_certificate_v1( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_birth_certificate_v1, null, data, null); + return result; + } + + /** + * 出生医学证明识别 + * @param url 图像链接 + * @param options 可选参数: + */ + Json::Value birth_certificate_v1_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_birth_certificate_v1, null, data, null); + return result; + } + + /** + * 名片识别 + * @param image 二进制图像数据 + * @param options 可选参数: + */ + Json::Value business_card_v1( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_business_card_v1, null, data, null); + return result; + } + + /** + * 名片识别 + * @param url 图像链接 + * @param options 可选参数: + */ + Json::Value business_card_v1_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_business_card_v1, null, data, null); + return result; + } + + /** + * 定额发票识别 + * @param image 二进制图像数据 + * @param options 可选参数: + */ + Json::Value quota_invoice_v1( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_quota_invoice_v1, null, data, null); + return result; + } + + /** + * 定额发票识别 + * @param url 图像链接 + * @param options 可选参数: + */ + Json::Value quota_invoice_v1_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_quota_invoice_v1, null, data, null); + return result; + } + + /** + * 定额发票识别 + * @param pdf_file 发票pdf文件二进制数据 + * options 可选参数: + */ + Json::Value quota_invoice_v1_pdf( + std::string const &pdf_file, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_quota_invoice_v1, null, data, null); + + return result; + } + + /** + * iOCR 通用版 + * @param image 二进制图像数据 + * @param options 可选参数: + */ + Json::Value recognise_iocr_v1( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_recognise_iocr_v1, null, data, null); + return result; + } + + /** + * iOCR 通用版 + * @param url 图像链接 + * @param options 可选参数: + */ + Json::Value recognise_iocr_v1_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_recognise_iocr_v1, null, data, null); + return result; + } + + /** + * iOCR 通用版 + * @param pdf_file 发票pdf文件二进制数据 + * options 可选参数: + */ + Json::Value recognise_iocr_v1_pdf( + std::string const &pdf_file, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_recognise_iocr_v1, null, data, null); + + return result; + } + + /** + * iOCR自定义模板文字识别 - 财会版 + * 参数详情参考:https://cloud.baidu.com/doc/OCR/s/yk3h7y9u3 + */ + Json::Value custom_finance( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_recognise_iocr_finance, null, data, null); + return result; + } + + /** + * iOCR自定义模板文字识别 - 财会版 + * 参数详情参考:https://cloud.baidu.com/doc/OCR/s/yk3h7y9u3 + */ + Json::Value custom_finance_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_recognise_iocr_finance, null, data, null); + return result; + } + + /** + * iOCR自定义模板文字识别 - 财会版 + * 参数详情参考:https://cloud.baidu.com/doc/OCR/s/yk3h7y9u3 + */ + Json::Value custom_finance_pdf( + std::string const &pdf_file, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_recognise_iocr_finance, null, data, null); + + return result; + } + + /** + * 汽车票识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Kkblx01ww + */ + Json::Value bus_ticket( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_bus_ticket, null, data, null); + return result; + } + + /** + * 汽车票识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Kkblx01ww + */ + Json::Value bus_ticket_url( + std::string url) { + std::map data; + data["url"] = url; + Json::Value result = + this->request(_bus_ticket, null, data, null); + return result; + } + + /** + * 过路过桥费发票识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/1kbpyx8js + */ + Json::Value toll_invoice( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_toll_invoice, null, data, null); + return result; + } + + /** + * 过路过桥费发票识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/1kbpyx8js + */ + Json::Value toll_invoice_url( + std::string url) { + std::map data; + data["url"] = url; + Json::Value result = + this->request(_toll_invoice, null, data, null); + return result; + } + + /** + * 多卡证类别检测 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/nkbq6wxxy + */ + Json::Value multi_card_classify( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_multi_card_classify, null, data, null); + return result; + } + + /** + * 多卡证类别检测 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/nkbq6wxxy + */ + Json::Value multi_card_classify_url( + std::string url) { + std::map data; + data["url"] = url; + Json::Value result = + this->request(_multi_card_classify, null, data, null); + return result; + } + + /** + * 智能结构化识别 - image (注:已下线) + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Qke3nkykj + */ + Json::Value intelligent_ocr( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_intelligent_ocr, null, data, null); + return result; + } + + /** + * 智能结构化识别 - url (注:已下线) + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Qke3nkykj + */ + Json::Value intelligent_ocr_url( + std::string url) { + std::map data; + data["url"] = url; + Json::Value result = + this->request(_intelligent_ocr, null, data, null); + return result; + } + + /** + * 病案首页识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/1ke30k2s2 + */ + Json::Value medical_record( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_record, null, data, null); + return result; + } + + /** + * 病案首页识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/1ke30k2s2 + */ + Json::Value medical_record_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_record, null, data, null); + return result; + } + + /** + * 医疗费用结算单识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Jke30ki7d + */ + Json::Value medical_statement( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_statement, null, data, null); + return result; + } + + /** + * 医疗费用结算单识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Jke30ki7d + */ + Json::Value medical_statement_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_statement, null, data, null); + return result; + } + + /** + * 船票识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/nkmcwp3ne + */ + Json::Value ferry_ticket( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_ferry_ticket, null, data, null); + return result; + } + + /** + * 船票识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/nkmcwp3ne + */ + Json::Value ferry_ticket_url( + std::string url) { + std::map data; + data["url"] = url; + Json::Value result = + this->request(_ferry_ticket, null, data, null); + return result; + } + + /** + * 二手车销售发票识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/8knr8rrj8 + */ + Json::Value used_vehicle_invoice( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_used_vehicle_invoice, null, data, null); + return result; + } + + /** + * 二手车销售发票识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/8knr8rrj8 + */ + Json::Value used_vehicle_invoice_url( + std::string url) { + std::map data; + data["url"] = url; + Json::Value result = + this->request(_used_vehicle_invoice, null, data, null); + return result; + } + + /** + * 身份证混贴识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/akp3gfbmc + */ + Json::Value multi_idcard( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_multi_idcard, null, data, null); + return result; + } + + /** + * 身份证混贴识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/akp3gfbmc + */ + Json::Value multi_idcard_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_multi_idcard, null, data, null); + return result; + } + + /** + * 通信行程卡识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Nksg89dkc + */ + Json::Value travel_card( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_travel_card, null, data, null); + return result; + } + + /** + * 社保卡识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/lkto93055 + */ + Json::Value social_security_card( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_social_security_card, null, data, null); + return result; + } + + /** + * 社保卡识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/lkto93055 + */ + Json::Value social_security_card_url( + std::string url) { + std::map data; + data["url"] = url; + Json::Value result = + this->request(_social_security_card, null, data, null); + return result; + } + + /** + * 医疗检验报告单识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Ekvakju92 + */ + Json::Value medical_report_detection( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_report_detection, null, data, null); + return result; + } + + /** + * 医疗检验报告单识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Ekvakju92 + */ + Json::Value medical_report_detection_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_report_detection, null, data, null); + return result; + } + + /** + * 医疗票据类别检测 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/zkvriu3sh + */ + Json::Value medical_recipts_classify( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_medical_recipts_classify, null, data, null); + return result; + } + + /** + * 医疗票据类别检测 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/zkvriu3sh + */ + Json::Value medical_recipts_classify_url( + std::string url) { + std::map data; + data["url"] = url; + Json::Value result = + this->request(_medical_recipts_classify, null, data, null); + return result; + } + + /** + * 快递面单识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Ekwkggqa5 + */ + Json::Value waybill( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_waybill, null, data, null); + return result; + } + + /** + * 快递面单识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Ekwkggqa5 + */ + Json::Value waybill_url( + std::string url) { + std::map data; + data["url"] = url; + Json::Value result = + this->request(_waybill, null, data, null); + return result; + } + + /** + * 出院小结识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Wkwwy4y4q + */ + Json::Value medical_summary( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_summary, null, data, null); + return result; + } + + /** + * 出院小结识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Wkwwy4y4q + */ + Json::Value medical_summary_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_summary, null, data, null); + return result; + } + + /** + * 购物小票识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/3kwvk8y36 + */ + Json::Value shopping_receipt( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_shopping_receipt, null, data, null); + return result; + } + + /** + * 购物小票识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/3kwvk8y36 + */ + Json::Value shopping_receipt_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_shopping_receipt, null, data, null); + return result; + } + + /** + * 购物小票识别 - pdf + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/3kwvk8y36 + */ + Json::Value shopping_receipt_pdf( + std::string const &pdf_file, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_shopping_receipt, null, data, null); + + return result; + } + + /** + * 道路运输证识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/ol07rjylw + */ + Json::Value road_transport_certificate( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_road_transport_certificate, null, data, null); + return result; + } + + /** + * 道路运输证识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/ol07rjylw + */ + Json::Value road_transport_certificate_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_road_transport_certificate, null, data, null); + return result; + } + + /** + * 道路运输证识别 - pdf + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/ol07rjylw + */ + Json::Value road_transport_certificate_pdf( + std::string const &pdf_file, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_road_transport_certificate, null, data, null); + + return result; + } + + /** + * 表格文字识别V2 - image + * 参数详情参考:https://cloud.baidu.com/doc/OCR/s/yk3h7y9u3 + */ + Json::Value table( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_table, null, data, null); + return result; + } + + /** + * 表格文字识别V2 - url + * 参数详情参考:https://cloud.baidu.com/doc/OCR/s/yk3h7y9u3 + */ + Json::Value table_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_table, null, data, null); + return result; + } + + /** + * 表格文字识别V2 - pdf + * 参数详情参考:https://cloud.baidu.com/doc/OCR/s/yk3h7y9u3 + */ + Json::Value table_pdf( + std::string const &pdf_file, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_table, null, data, null); + + return result; + } + + /** + * 文档去手写 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/il4tb1jay + */ + Json::Value remove_handwriting( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_remove_handwriting, null, data, null); + return result; + } + + /** + * 文档去手写 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/il4tb1jay + */ + Json::Value remove_handwriting_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_remove_handwriting, null, data, null); + return result; + } + + /** + * 文档去手写 - pdf + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/il4tb1jay + */ + Json::Value remove_handwriting_pdf( + std::string const &pdf_file, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_remove_handwriting, null, data, null); + + return result; + } + + /** + * 文档矫正增强 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Hl4taza5f + */ + Json::Value doc_crop_enhance( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_doc_crop_enhance, null, data, null); + return result; + } + + /** + * 文档矫正增强 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Hl4taza5f + */ + Json::Value doc_crop_enhance_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_doc_crop_enhance, null, data, null); + return result; + } + + /** + * 文档矫正增强 - pdf + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Hl4taza5f + */ + Json::Value doc_crop_enhance_pdf( + std::string const &pdf_file, + const std::map &options) { + std::map data; + + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_doc_crop_enhance, null, data, null); + + return result; + } + + /** + * 健康码识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Ol52hedan + */ + Json::Value health_code( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_health_code, null, data, null); + return result; + } + + /** + * 核酸证明识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/nl56ibk44 + */ + Json::Value covid_test( + std::string image) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + Json::Value result = + this->request(_covid_test, null, data, null); + return result; + } + + /** + * 处方笺识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Pl59exph0 + */ + Json::Value medical_prescription( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_prescription, null, data, null); + return result; + } + + /** + * 处方笺识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Pl59exph0 + */ + Json::Value medical_prescription_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_prescription, null, data, null); + return result; + } + + /** + * 门诊病历识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/ll59eepzw + */ + Json::Value medical_outpatient( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_outpatient, null, data, null); + return result; + } + + /** + * 门诊病历识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/ll59eepzw + */ + Json::Value medical_outpatient_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_outpatient, null, data, null); + return result; + } + + /** + * 诊断证明识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Dl59e3ohe + */ + Json::Value medical_summary_diagnosis( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_summary_diagnosis, null, data, null); + return result; + } + + /** + * 诊断证明识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Dl59e3ohe + */ + Json::Value medical_summary_diagnosis_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_medical_summary_diagnosis, null, data, null); + return result; + } + + /** + * 医疗诊断报告单识别 - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/El59es47z + */ + Json::Value health_report( + std::string image, + std::map options) { + std::map data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_health_report, null, data, null); + return result; + } + + /** + * 医疗诊断报告单识别 - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/El59es47z + */ + Json::Value health_report_url( + std::string url, + std::map options) { + std::map data; + data["url"] = url; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + Json::Value result = + this->request(_health_report, null, data, null); + return result; + } + + /** + * 图文转换器(接口版) - image + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Elf3sp7cz + */ + Json::Value doc_convert_request_v1( + std::string const &image, + const Json::Value &options) { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_doc_convert_request_v1, data, &headers); + return result; + } + + /** + * 图文转换器(接口版) - url + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Elf3sp7cz + */ + Json::Value doc_convert_request_v1_url( + std::string const &url, + const Json::Value &options) { + Json::Value data; + data["url"] = url; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_doc_convert_request_v1, data, &headers); + return result; + } + + /** + * 图文转换器(接口版) - pdf + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Elf3sp7cz + */ + Json::Value doc_convert_request_v1_pdf( + std::string const &pdf_file, + const Json::Value &options) { + Json::Value data; + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_doc_convert_request_v1, data, &headers); + return result; + } + + /** + * 图文转换器(接口版) - 获取结果 + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Elf3sp7cz + */ + Json::Value doc_convert_result_v1(std::string const &task_id) { + Json::Value data; + data["task_id"] = task_id; + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_doc_convert_result_v1, data, &headers); + return result; + } + + /** + * 银行回单识别 + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Plep1yzi9 + */ + Json::Value bank_receipt_new( + std::string const &image, + const Json::Value &options) { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_bank_receipt_new, data, &headers); + return result; + } + + /** + * 银行回单识别 + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Plep1yzi9 + */ + Json::Value bank_receipt_new_url( + std::string const &url, + const Json::Value &options) { + Json::Value data; + data["url"] = url; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_bank_receipt_new, data, &headers); + return result; + } + + /** + * 银行回单识别 + * 参数详情参考:https://ai.baidu.com/ai-doc/OCR/Plep1yzi9 + */ + Json::Value bank_receipt_new_pdf( + std::string const &pdf_file, + const Json::Value &options) { + Json::Value data; + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_bank_receipt_new, data, &headers); + return result; + } + + /** + * 结婚证识别 + * 参数详情参考: https://ai.baidu.com/ai-doc/OCR/Klg67mfkc + */ + Json::Value marriage_certificate( + std::string const & image, + const Json::Value & options) + { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_marriage_certificate, data, &headers); + return result; + } + + /** + * 结婚证识别 + * 参数详情参考: https://ai.baidu.com/ai-doc/OCR/Klg67mfkc + */ + Json::Value marriage_certificate_url( + std::string const & url, + const Json::Value & options) + { + Json::Value data; + data["url"] = url; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_marriage_certificate, data, &headers); + return result; + } + /** + * 结婚证识别 + * 参数详情参考: https://ai.baidu.com/ai-doc/OCR/Klg67mfkc + */ + Json::Value marriage_certificate_pdf( + std::string const & pdf_file, + const Json::Value & options) + { + Json::Value data; + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_marriage_certificate, data, &headers); + return result; + } + + /** + * 港澳台证件识别 + * 参数详情参考: https://ai.baidu.com/ai-doc/OCR/Tlg6859ns + */ + Json::Value hk_macau_taiwan_exitentrypermit( + std::string const & image, + std::string const & exitentrypermit_type, + const Json::Value & options) + { + Json::Value data; + data["image"] = base64_encode(image.c_str(), (int) image.size()); + data["exitentrypermit_type"] = exitentrypermit_type; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_hk_macau_taiwan_exitentrypermit, data, &headers); + return result; + } + + /** + * 港澳台证件识别 + * 参数详情参考: https://ai.baidu.com/ai-doc/OCR/Tlg6859ns + */ + Json::Value hk_macau_taiwan_exitentrypermit_url( + std::string const & url, + std::string const & exitentrypermit_type, + const Json::Value & options) + { + Json::Value data; + data["url"] = url; + data["exitentrypermit_type"] = exitentrypermit_type; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_hk_macau_taiwan_exitentrypermit, data, &headers); + return result; + } + /** + * 港澳台证件识别 + * 参数详情参考: https://ai.baidu.com/ai-doc/OCR/Tlg6859ns + */ + Json::Value hk_macau_taiwan_exitentrypermit_pdf( + std::string const & pdf_file, + std::string const & exitentrypermit_type, + const Json::Value & options) + { + Json::Value data; + data["pdf_file"] = base64_encode(pdf_file.c_str(), (int) pdf_file.size()); + data["exitentrypermit_type"] = exitentrypermit_type; + merge_json(data, options); + + std::map headers; + headers["Content-Type"] = "application/x-www-form-urlencoded"; + Json::Value result = this->request_com(_hk_macau_taiwan_exitentrypermit, data, &headers); + return result; + } + }; +} +#endif diff --git a/third/include/aip-cpp-sdk/speech.h b/third/include/aip-cpp-sdk/speech.h new file mode 100644 index 0000000..c3e657a --- /dev/null +++ b/third/include/aip-cpp-sdk/speech.h @@ -0,0 +1,135 @@ +#ifndef __AIP_SPEECH_H__ +#define __AIP_SPEECH_H__ + +#include "base/base.h" + +#include + +namespace aip { + + class Speech : public AipBase { + public: + + std::string _asr = "https://vop.baidu.com/server_api"; + + std::string _tts = "http://tsn.baidu.com/text2audio"; + + Speech(const std::string app_id, const std::string &ak, const std::string &sk) : AipBase(app_id, ak, sk) { + } + + Json::Value request_asr( + std::string const &url, + Json::Value &data) { + std::string response; + Json::Value obj; + int status_code = this->client.post(url, nullptr, data, nullptr, &response); + + if (status_code != CURLcode::CURLE_OK) { + obj[aip::CURL_ERROR_CODE] = status_code; + return obj; + } + + std::string error; + std::unique_ptr reader(crbuilder.newCharReader()); + reader->parse(response.data(), response.data() + response.size(), &obj, &error); + + return obj; + } + + Json::Value request_tts( + const std::string url, + std::map &data, + std::string &file_content) { + std::string response; + Json::Value obj; + Json::Value file_json; + int status_code = this->client.post(url, nullptr, data, nullptr, &response); + if (status_code != CURLcode::CURLE_OK) { + obj[aip::CURL_ERROR_CODE] = status_code; + return obj; + } + + file_content = response; + + return obj; + } + + Json::Value recognize(const std::string voice_binary, const std::string &format, const int &rate, + std::map const &options) { + Json::Value data; + + std::map::const_iterator it; + for (it = options.begin(); it != options.end(); it++) { + data[it->first] = it->second; + if (it->first == "dev_pid") { + data[it->first] = atoi(it->second.c_str()); + } + } + + std::string token = this->getAccessToken(); + + data["speech"] = base64_encode(voice_binary.c_str(), (int) voice_binary.size()); + data["format"] = format; + data["rate"] = std::to_string(rate); + data["channel"] = "1"; + data["token"] = token; + data["cuid"] = this->getAk(); + data["len"] = (int) voice_binary.size(); + + Json::Value result = this->request_asr(_asr, data); + return result; + } + + Json::Value recognize_url(const std::string &url, + const std::string &callback, const std::string &format, + const int &rate, + std::map options) { + Json::Value data; + std::map::iterator it; + + for (it = options.begin(); it != options.end(); it++) { + data[it->first] = it->second; + if (it->first == "dev_pid") { + data[it->first] = atoi(it->second.c_str()); + } + } + + std::string token = this->getAccessToken(); + + data["url"] = url; + data["callback"] = callback; + data["format"] = format; + data["rate"] = std::to_string(rate); + data["channel"] = 1; + data["token"] = token; + data["cuid"] = this->getAk(); + + Json::Value result = this->request_asr(_asr, data); + return result; + } + + Json::Value text2audio(const std::string &text, std::map const &options, + std::string &file_content) { + std::map data; + std::map::const_iterator it; + + for (it = options.begin(); it != options.end(); it++) { + data[it->first] = it->second; + } + + std::string token = this->getAccessToken(); + + data["tex"] = text; + data["lan"] = "zh"; + data["ctp"] = "1"; + data["tok"] = token; + data["cuid"] = this->getAk(); + + Json::Value result = this->request_tts(_tts, data, file_content); + return result; + } + + }; + +} +#endif diff --git a/third/include/aip-cpp-sdk/video_censor.h b/third/include/aip-cpp-sdk/video_censor.h new file mode 100644 index 0000000..f4e5159 --- /dev/null +++ b/third/include/aip-cpp-sdk/video_censor.h @@ -0,0 +1,63 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_VIDEOCENSOR_H__ +#define __AIP_VIDEOCENSOR_H__ + +#include "base/base.h" + +namespace aip { + + class Videocensor: public AipBase + { + public: + + + std::string _video_url = + "https://aip.baidubce.com/rest/2.0/solution/v1/video_censor/v2/user_defined"; + + + Videocensor(const std::string & app_id, const std::string & ak, const std::string & sk): AipBase(app_id, ak, sk) + { + } + + /** + * voice_censor + * 本接口除了支持自定义配置外,还对返回结果进行了总体的包装,按照用户在控制台中配置的规则直接返回是否合规,如果不合规则指出具体不合规的内容。 + * @param voice 语音文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value video_censor( + std::string const & name, + std::string const & url, + std::string const & extId, + const std::map & options) + { + std::map data; + + data["videoUrl"] = url; + data["name"] = name; + data["extId"] = extId; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_video_url, null, data, null); + + return result; + } + + }; +} +#endif diff --git a/third/include/aip-cpp-sdk/voice_censor.h b/third/include/aip-cpp-sdk/voice_censor.h new file mode 100644 index 0000000..1517dc4 --- /dev/null +++ b/third/include/aip-cpp-sdk/voice_censor.h @@ -0,0 +1,87 @@ +/** + * Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * @author baidu aip + */ + +#ifndef __AIP_IMAGECENSOR_H__ +#define __AIP_IMAGECENSOR_H__ + +#include "base/base.h" + +namespace aip { + + class Voicecensor: public AipBase + { + public: + + + std::string _voice_url = + "https://aip.baidubce.com/rest/2.0/solution/v1/voice_censor/v3/user_defined"; + + + Voicecensor(const std::string & app_id, const std::string & ak, const std::string & sk): AipBase(app_id, ak, sk) + { + } + + /** + * voice_censor + * 本接口除了支持自定义配置外,还对返回结果进行了总体的包装,按照用户在控制台中配置的规则直接返回是否合规,如果不合规则指出具体不合规的内容。 + * @param voice 语音文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value voice_censor( + std::string const & voice, + std::int32_t const & rate, + std::string const & fmt, + const std::map & options) + { + std::map data; + + data["base64"] = base64_encode(voice.c_str(), (int) voice.size()); + data["fmt"] = fmt; + data["rate"] = rate; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_voice_url, null, data, null); + + return result; + } + + /** + * voice_censor + * 本接口除了支持自定义配置外,还对返回结果进行了总体的包装,按照用户在控制台中配置的规则直接返回是否合规,如果不合规则指出具体不合规的内容。 + * @param voice 语音文件二进制内容,可以使用aip::get_file_content函数获取 + * options 可选参数: + */ + Json::Value voice_censorUrl( + std::string const & url, + std::int32_t const & rate, + std::string const & fmt, + const std::map & options) + { + std::map data; + + data["url"] = url; + data["fmt"] = fmt; + data["rate"] = rate; + std::copy(options.begin(), options.end(), std::inserter(data, data.end())); + + Json::Value result = + this->request(_voice_url, null, data, null); + + return result; + } + }; +} +#endif diff --git a/transmite/CMakeLists.txt b/transmite/CMakeLists.txt new file mode 100644 index 0000000..f4ee930 --- /dev/null +++ b/transmite/CMakeLists.txt @@ -0,0 +1,92 @@ +# 1. 添加cmake版本说明 +cmake_minimum_required(VERSION 3.1.3) +# 2. 声明工程名称 +project(transmite_server) + +set(target "transmite_server") + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的proto映射代码文件名称 +set(proto_path ${CMAKE_CURRENT_SOURCE_DIR}/../proto) +set(proto_files base.proto user.proto transmite.proto) +# 2. 检测框架代码文件是否已经生成 +set(proto_hxx "") +set(proto_cxx "") +set(proto_srcs "") +foreach(proto_file ${proto_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".proto" ".pb.cc" proto_cc ${proto_file}) + string(REPLACE ".proto" ".pb.h" proto_hh ${proto_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${proto_cc}) + add_custom_command( + PRE_BUILD + COMMAND protoc + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} -I ${proto_path} --experimental_allow_proto3_optional ${proto_path}/${proto_file} + DEPENDS ${proto_path}/${proto_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + COMMENT "生成Protobuf框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + ) + endif() + list(APPEND proto_srcs ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc}) +endforeach() + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的odb映射代码文件名称 +set(odb_path ${CMAKE_CURRENT_SOURCE_DIR}/../odb) +set(odb_files chat_session_member.hxx) +# 2. 检测框架代码文件是否已经生成 +set(odb_hxx "") +set(odb_cxx "") +set(odb_srcs "") +foreach(odb_file ${odb_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".hxx" "-odb.hxx" odb_hxx ${odb_file}) + string(REPLACE ".hxx" "-odb.cxx" odb_cxx ${odb_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${odb_cxx}) + add_custom_command( + PRE_BUILD + COMMAND odb + ARGS -d mysql --std c++11 --generate-query --generate-schema --profile boost/date-time ${odb_path}/${odb_file} + DEPENDS ${odb_path}/${odb_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx} + COMMENT "生成ODB框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx} + ) + endif() +# 4. 将所有生成的框架源码文件名称保存起来 student-odb.cxx classes-odb.cxx + list(APPEND odb_srcs ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx}) +endforeach() + +# 4. 获取源码目录下的所有源码文件 +set(src_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/source src_files) +# 5. 声明目标及依赖 +add_executable(${target} ${src_files} ${proto_srcs} ${odb_srcs}) +# 7. 设置需要连接的库 +target_link_libraries(${target} -lgflags + -lspdlog -lfmt -lbrpc -lssl -lcrypto + -lprotobuf -lleveldb -letcd-cpp-api + -lcpprest -lcurl -lodb-mysql -lodb -lodb-boost + -lamqpcpp -lev) + + +set(trans_user_client "trans_user_client") +set(trans_user_files ${CMAKE_CURRENT_SOURCE_DIR}/test/user_client.cc) +add_executable(${trans_user_client} ${trans_user_files} ${proto_srcs}) +target_link_libraries(${trans_user_client} -pthread -lgtest -lgflags -lspdlog -lfmt -lbrpc -lssl -lcrypto -lprotobuf -lleveldb -letcd-cpp-api -lcpprest -lcurl /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19) + +set(transmite_client "transmite_client") +set(transmite_files ${CMAKE_CURRENT_SOURCE_DIR}/test/transmite_client.cc) +add_executable(${transmite_client} ${transmite_files} ${proto_srcs}) +target_link_libraries(${transmite_client} -pthread -lgtest -lgflags -lspdlog -lfmt -lbrpc -lssl -lcrypto -lprotobuf -lleveldb -letcd-cpp-api -lcpprest -lcurl /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19) + + +# 6. 设置头文件默认搜索路径 +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../common) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../odb) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third/include) + +#8. 设置安装路径 +INSTALL(TARGETS ${target} ${trans_user_client} ${transmite_client} RUNTIME DESTINATION bin) \ No newline at end of file diff --git a/transmite/dockerfile b/transmite/dockerfile new file mode 100644 index 0000000..172ab21 --- /dev/null +++ b/transmite/dockerfile @@ -0,0 +1,16 @@ +# 声明基础经镜像来源 +FROM debian:12 + +# 声明工作目录 +WORKDIR /im +RUN mkdir -p /im/logs &&\ + mkdir -p /im/data &&\ + mkdir -p /im/conf &&\ + mkdir -p /im/bin + +# 将可执行程序依赖,拷贝进镜像 +COPY ./build/transmite_server /im/bin/ +# 将可执行程序文件,拷贝进镜像 +COPY ./depends /lib/x86_64-linux-gnu/ +# 设置容器启动的默认操作 ---运行程序 +CMD /im/bin/transmite_server -flagfile=/im/conf/transmite_server.conf \ No newline at end of file diff --git a/transmite/source/transmite_server.cc b/transmite/source/transmite_server.cc new file mode 100644 index 0000000..0896b90 --- /dev/null +++ b/transmite/source/transmite_server.cc @@ -0,0 +1,51 @@ +//主要实现语音识别子服务的服务器的搭建 +#include "transmite_server.hpp" + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(registry_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(instance_name, "/transmite_service/instance", "当前实例名称"); +DEFINE_string(access_host, "127.0.0.1:10004", "当前实例的外部访问地址"); + +DEFINE_int32(listen_port, 10004, "Rpc服务器监听端口"); +DEFINE_int32(rpc_timeout, -1, "Rpc调用超时时间"); +DEFINE_int32(rpc_threads, 1, "Rpc的IO线程数量"); + +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(user_service, "/service/user_service", "用户管理子服务名称"); + +DEFINE_string(mysql_host, "127.0.0.1", "Mysql服务器访问地址"); +DEFINE_string(mysql_user, "root", "Mysql服务器访问用户名"); +DEFINE_string(mysql_pswd, "123456", "Mysql服务器访问密码"); +DEFINE_string(mysql_db, "bite_im", "Mysql默认库名称"); +DEFINE_string(mysql_cset, "utf8", "Mysql客户端字符集"); +DEFINE_int32(mysql_port, 0, "Mysql服务器访问端口"); +DEFINE_int32(mysql_pool_count, 4, "Mysql连接池最大连接数量"); + +DEFINE_string(mq_user, "root", "消息队列服务器访问用户名"); +DEFINE_string(mq_pswd, "123456", "消息队列服务器访问密码"); +DEFINE_string(mq_host, "127.0.0.1:5672", "消息队列服务器访问地址"); +DEFINE_string(mq_msg_exchange, "msg_exchange", "持久化消息的发布交换机名称"); +DEFINE_string(mq_msg_queue, "msg_queue", "持久化消息的发布队列名称"); +DEFINE_string(mq_msg_binding_key, "msg_queue", "持久化消息的发布队列名称"); + + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + bite_im::TransmiteServerBuilder tsb; + tsb.make_mq_object(FLAGS_mq_user, FLAGS_mq_pswd, FLAGS_mq_host, + FLAGS_mq_msg_exchange, FLAGS_mq_msg_queue, FLAGS_mq_msg_binding_key); + tsb.make_mysql_object(FLAGS_mysql_user, FLAGS_mysql_pswd, FLAGS_mysql_host, + FLAGS_mysql_db, FLAGS_mysql_cset, FLAGS_mysql_port, FLAGS_mysql_pool_count); + tsb.make_discovery_object(FLAGS_registry_host, FLAGS_base_service, FLAGS_user_service); + tsb.make_rpc_server(FLAGS_listen_port, FLAGS_rpc_timeout, FLAGS_rpc_threads); + tsb.make_registry_object(FLAGS_registry_host, FLAGS_base_service + FLAGS_instance_name, FLAGS_access_host); + auto server = tsb.build(); + server->start(); + return 0; +} \ No newline at end of file diff --git a/transmite/source/transmite_server.hpp b/transmite/source/transmite_server.hpp new file mode 100644 index 0000000..b2d8c71 --- /dev/null +++ b/transmite/source/transmite_server.hpp @@ -0,0 +1,235 @@ +//实现语音识别子服务 +#include +#include + +#include "etcd.hpp" // 服务注册模块封装 +#include "logger.hpp" // 日志模块封装 +#include "rabbitmq.hpp" +#include "channel.hpp" +#include "utils.hpp" +#include "mysql_chat_session_member.hpp" + +#include "base.pb.h" // protobuf框架代码 +#include "user.pb.h" // protobuf框架代码 +#include "transmite.pb.h" // protobuf框架代码 + +namespace bite_im{ +class TransmiteServiceImpl : public bite_im::MsgTransmitService { + public: + TransmiteServiceImpl(const std::string &user_service_name, + const ServiceManager::ptr &channels, + const std::shared_ptr &mysql_client, + const std::string &exchange_name, + const std::string &routing_key, + const MQClient::ptr &mq_client): + _user_service_name(user_service_name), + _mm_channels(channels), + _mysql_session_member_table(std::make_shared(mysql_client)), + _exchange_name(exchange_name), + _routing_key(routing_key), + _mq_client(mq_client){} + ~TransmiteServiceImpl(){} + void GetTransmitTarget(google::protobuf::RpcController* controller, + const ::bite_im::NewMessageReq* request, + ::bite_im::GetTransmitTargetRsp* response, + ::google::protobuf::Closure* done) override { + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //从请求中获取关键信息:用户ID,所属会话ID,消息内容 + std::string rid = request->request_id(); + std::string uid = request->user_id(); + std::string chat_ssid = request->chat_session_id(); + const MessageContent &content = request->message(); + // 进行消息组织:发送者-用户子服务获取信息,所属会话,消息内容,产生时间,消息ID + auto channel = _mm_channels->choose(_user_service_name); + if (!channel) { + LOG_ERROR("{}-{} 没有可供访问的用户子服务节点!", rid, _user_service_name); + return err_response(rid, "没有可供访问的用户子服务节点!"); + } + UserService_Stub stub(channel.get()); + GetUserInfoReq req; + GetUserInfoRsp rsp; + req.set_request_id(rid); + req.set_user_id(uid); + brpc::Controller cntl; + stub.GetUserInfo(&cntl, &req, &rsp, nullptr); + if (cntl.Failed() == true || rsp.success() == false) { + LOG_ERROR("{} - 用户子服务调用失败:{}!", request->request_id(), cntl.ErrorText()); + return err_response(request->request_id(), "用户子服务调用失败!"); + } + MessageInfo message; + message.set_message_id(uuid()); + message.set_chat_session_id(chat_ssid); + message.set_timestamp(time(nullptr)); + message.mutable_sender()->CopyFrom(rsp.user_info()); + message.mutable_message()->CopyFrom(content); + // 获取消息转发客户端用户列表 + auto target_list = _mysql_session_member_table->members(chat_ssid); + // 将封装完毕的消息,发布到消息队列,待消息存储子服务进行消息持久化 + bool ret = _mq_client->publish(_exchange_name, message.SerializeAsString(), _routing_key); + if (ret == false) { + LOG_ERROR("{} - 持久化消息发布失败:{}!", request->request_id(), cntl.ErrorText()); + return err_response(request->request_id(), "持久化消息发布失败:!"); + } + //组织响应 + response->set_request_id(rid); + response->set_success(true); + response->mutable_message()->CopyFrom(message); + for (const auto &id : target_list) { + response->add_target_id_list(id); + } + } + private: + //用户子服务调用相关信息 + std::string _user_service_name; + ServiceManager::ptr _mm_channels; + + //聊天会话成员表的操作句柄 + ChatSessionMemeberTable::ptr _mysql_session_member_table; + + //消息队列客户端句柄 + std::string _exchange_name; + std::string _routing_key; + MQClient::ptr _mq_client; +}; + +class TransmiteServer { + public: + using ptr = std::shared_ptr; + TransmiteServer( + const std::shared_ptr &mysql_client, + const Discovery::ptr discovery_client, + const Registry::ptr ®istry_client, + const std::shared_ptr &server): + _service_discoverer(discovery_client), + _registry_client(registry_client), + _mysql_client(mysql_client), + _rpc_server(server){} + ~TransmiteServer(){} + //搭建RPC服务器,并启动服务器 + void start() { + _rpc_server->RunUntilAskedToQuit(); + } + private: + Discovery::ptr _service_discoverer; //服务发现客户端 + Registry::ptr _registry_client; // 服务注册客户端 + std::shared_ptr _mysql_client; //mysql数据库客户端 + std::shared_ptr _rpc_server; +}; + +class TransmiteServerBuilder { + public: + //构造mysql客户端对象 + void make_mysql_object( + const std::string &user, + const std::string &pswd, + const std::string &host, + const std::string &db, + const std::string &cset, + int port, + int conn_pool_count) { + _mysql_client = ODBFactory::create(user, pswd, host, db, cset, port, conn_pool_count); + } + //用于构造服务发现客户端&信道管理对象 + void make_discovery_object(const std::string ®_host, + const std::string &base_service_name, + const std::string &user_service_name) { + _user_service_name = user_service_name; + _mm_channels = std::make_shared(); + _mm_channels->declared(user_service_name); + LOG_DEBUG("设置用户子服务为需添加管理的子服务:{}", user_service_name); + auto put_cb = std::bind(&ServiceManager::onServiceOnline, _mm_channels.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&ServiceManager::onServiceOffline, _mm_channels.get(), std::placeholders::_1, std::placeholders::_2); + _service_discoverer = std::make_shared(reg_host, base_service_name, put_cb, del_cb); + } + //用于构造服务注册客户端对象 + void make_registry_object(const std::string ®_host, + const std::string &service_name, + const std::string &access_host) { + _registry_client = std::make_shared(reg_host); + _registry_client->registry(service_name, access_host); + } + //用于构造rabbitmq客户端对象 + void make_mq_object(const std::string &user, + const std::string &passwd, + const std::string &host, + const std::string &exchange_name, + const std::string &queue_name, + const std::string &binding_key) { + _routing_key = binding_key; + _exchange_name = exchange_name; + _mq_client = std::make_shared(user, passwd, host); + _mq_client->declareComponents(exchange_name, queue_name, binding_key); + } + //构造RPC服务器对象 + void make_rpc_server(uint16_t port, int32_t timeout, uint8_t num_threads) { + if (!_mq_client) { + LOG_ERROR("还未初始化消息队列客户端模块!"); + abort(); + } + if (!_mm_channels) { + LOG_ERROR("还未初始化信道管理模块!"); + abort(); + } + if (!_mysql_client) { + LOG_ERROR("还未初始化Mysql数据库模块!"); + abort(); + } + + _rpc_server = std::make_shared(); + + TransmiteServiceImpl *transmite_service = new TransmiteServiceImpl( + _user_service_name, _mm_channels, _mysql_client, _exchange_name, _routing_key, _mq_client); + + int ret = _rpc_server->AddService(transmite_service, + brpc::ServiceOwnership::SERVER_OWNS_SERVICE); + if (ret == -1) { + LOG_ERROR("添加Rpc服务失败!"); + abort(); + } + brpc::ServerOptions options; + options.idle_timeout_sec = timeout; + options.num_threads = num_threads; + ret = _rpc_server->Start(port, &options); + if (ret == -1) { + LOG_ERROR("服务启动失败!"); + abort(); + } + } + TransmiteServer::ptr build() { + if (!_service_discoverer) { + LOG_ERROR("还未初始化服务发现模块!"); + abort(); + } + if (!_registry_client) { + LOG_ERROR("还未初始化服务注册模块!"); + abort(); + } + if (!_rpc_server) { + LOG_ERROR("还未初始化RPC服务器模块!"); + abort(); + } + TransmiteServer::ptr server = std::make_shared( + _mysql_client, _service_discoverer, _registry_client, _rpc_server); + return server; + } + private: + std::string _user_service_name; + ServiceManager::ptr _mm_channels; + Discovery::ptr _service_discoverer; + + std::string _routing_key; + std::string _exchange_name; + MQClient::ptr _mq_client; + + Registry::ptr _registry_client; // 服务注册客户端 + std::shared_ptr _mysql_client; //mysql数据库客户端 + std::shared_ptr _rpc_server; +}; +} \ No newline at end of file diff --git a/transmite/test/mysql_test/main.cc b/transmite/test/mysql_test/main.cc new file mode 100644 index 0000000..86ec2fe --- /dev/null +++ b/transmite/test/mysql_test/main.cc @@ -0,0 +1,66 @@ +#include "../../../common/mysql_chat_session_member.hpp" +// #include "../../../odb/chat_session_member.hxx" +// #include "chat_session_member-odb.hxx" +#include + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +void append_test(bite_im::ChatSessionMemeberTable &tb) { + bite_im::ChatSessionMember csm1("会话ID1", "用户ID1"); + tb.append(csm1); + bite_im::ChatSessionMember csm2("会话ID1", "用户ID2"); + tb.append(csm2); + bite_im::ChatSessionMember csm3("会话ID2", "用户ID3"); + tb.append(csm3); +} + +void multi_append_test(bite_im::ChatSessionMemeberTable &tb) { + + // bite_im::ChatSessionMember csm1("会话ID3", "用户ID1"); + // bite_im::ChatSessionMember csm2("会话ID3", "用户ID2"); + // bite_im::ChatSessionMember csm3("会话ID3", "用户ID3"); + // std::vector list = {csm1, csm2, csm3}; + // tb.append(list); + + std::vector list; + list.emplace_back("会话ID3", "用户ID1"); + list.emplace_back("会话ID3", "用户ID2"); + list.emplace_back("会话ID3", "用户ID3"); + tb.append(list); +} + +void remove_test(bite_im::ChatSessionMemeberTable &tb) { + bite_im::ChatSessionMember csm3("会话ID2", "用户ID3"); + tb.remove(csm3); +} + +void ss_members(bite_im::ChatSessionMemeberTable &tb) { + auto res = tb.members("会话ID1"); + for (auto &id : res) { + std::cout << id << std::endl; + } +} + +void remove_all(bite_im::ChatSessionMemeberTable &tb) { + tb.remove("会话ID3"); +} + +int main(int argc, char* argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + auto db = bite_im::ODBFactory::create("root", "123456", "127.0.0.1", "bite_im", "utf8", 0, 1); + + bite_im::ChatSessionMemeberTable csmt(db); + + append_test(csmt); + multi_append_test(csmt); + // remove_test(csmt); + // ss_members(csmt); + + // remove_all(csmt); + return 0; +} \ No newline at end of file diff --git a/transmite/test/transmite_client.cc b/transmite/test/transmite_client.cc new file mode 100644 index 0000000..6adc076 --- /dev/null +++ b/transmite/test/transmite_client.cc @@ -0,0 +1,129 @@ +//speech_server的测试客户端实现 +//1. 进行服务发现--发现speech_server的服务器节点地址信息并实例化的通信信道 +//2. 读取语音文件数据 +//3. 发起语音识别RPC调用 + +#include "etcd.hpp" +#include "channel.hpp" +#include "utils.hpp" +#include +#include +#include +#include "transmite.pb.h" + + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(etcd_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(transmite_service, "/service/transmite_service", "服务监控根目录"); + +bite_im::ServiceManager::ptr sm; + +void string_message(const std::string &uid, const std::string &sid, const std::string &msg) { + auto channel = sm->choose(FLAGS_transmite_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::MsgTransmitService_Stub stub(channel.get()); + bite_im::NewMessageReq req; + bite_im::GetTransmitTargetRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid); + req.set_chat_session_id(sid); + req.mutable_message()->set_message_type(bite_im::MessageType::STRING); + req.mutable_message()->mutable_string_message()->set_content(msg); + brpc::Controller cntl; + stub.GetTransmitTarget(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); +} + +void image_message(const std::string &uid, const std::string &sid, const std::string &msg) { + auto channel = sm->choose(FLAGS_transmite_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::MsgTransmitService_Stub stub(channel.get()); + bite_im::NewMessageReq req; + bite_im::GetTransmitTargetRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid); + req.set_chat_session_id(sid); + req.mutable_message()->set_message_type(bite_im::MessageType::IMAGE); + req.mutable_message()->mutable_image_message()->set_image_content(msg); + brpc::Controller cntl; + stub.GetTransmitTarget(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); +} + +void speech_message(const std::string &uid, const std::string &sid, const std::string &msg) { + auto channel = sm->choose(FLAGS_transmite_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::MsgTransmitService_Stub stub(channel.get()); + bite_im::NewMessageReq req; + bite_im::GetTransmitTargetRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid); + req.set_chat_session_id(sid); + req.mutable_message()->set_message_type(bite_im::MessageType::SPEECH); + req.mutable_message()->mutable_speech_message()->set_file_contents(msg); + brpc::Controller cntl; + stub.GetTransmitTarget(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); +} + +void file_message(const std::string &uid, const std::string &sid, + const std::string &filename, const std::string &content) { + auto channel = sm->choose(FLAGS_transmite_service); + if (!channel) { + std::cout << "获取通信信道失败!" << std::endl; + return; + } + bite_im::MsgTransmitService_Stub stub(channel.get()); + bite_im::NewMessageReq req; + bite_im::GetTransmitTargetRsp rsp; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid); + req.set_chat_session_id(sid); + req.mutable_message()->set_message_type(bite_im::MessageType::FILE); + req.mutable_message()->mutable_file_message()->set_file_contents(content); + req.mutable_message()->mutable_file_message()->set_file_name(filename); + req.mutable_message()->mutable_file_message()->set_file_size(content.size()); + brpc::Controller cntl; + stub.GetTransmitTarget(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); +} + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + + //1. 先构造Rpc信道管理对象 + sm = std::make_shared(); + sm->declared(FLAGS_transmite_service); + auto put_cb = std::bind(&bite_im::ServiceManager::onServiceOnline, sm.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&bite_im::ServiceManager::onServiceOffline, sm.get(), std::placeholders::_1, std::placeholders::_2); + //2. 构造服务发现对象 + bite_im::Discovery::ptr dclient = std::make_shared(FLAGS_etcd_host, FLAGS_base_service, put_cb, del_cb); + + //3. 通过Rpc信道管理对象,获取提供Echo服务的信道 + string_message("672f-c755e83e-0000", "会话ID1", "吃饭了吗?"); + string_message("ee55-9043bfd7-0001", "会话ID1", "吃的盖浇饭!!"); + image_message("672f-c755e83e-0000", "会话ID1", "可爱表情图片数据"); + speech_message("672f-c755e83e-0000", "会话ID1", "动听猪叫声数据"); + file_message("672f-c755e83e-0000", "会话ID1", "猪爸爸的文件名称", "猪爸爸的文件数据"); + return 0; +} \ No newline at end of file diff --git a/transmite/test/user_client.cc b/transmite/test/user_client.cc new file mode 100644 index 0000000..d94f737 --- /dev/null +++ b/transmite/test/user_client.cc @@ -0,0 +1,80 @@ +#include "etcd.hpp" +#include "channel.hpp" +#include "utils.hpp" +#include +#include +#include +#include "user.pb.h" +#include "base.pb.h" + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(etcd_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(user_service, "/service/user_service", "服务监控根目录"); + +bite_im::ServiceManager::ptr user_channels; +void reg_user(const std::string &nickname, const std::string &pswd) { + auto channel = user_channels->choose(FLAGS_user_service);//获取通信信道 + ASSERT_TRUE(channel); + + bite_im::UserRegisterReq req; + req.set_request_id(bite_im::uuid()); + req.set_nickname(nickname); + req.set_password(pswd); + + bite_im::UserRegisterRsp rsp; + brpc::Controller cntl; + bite_im::UserService_Stub stub(channel.get()); + stub.UserRegister(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); +} + +void set_user_avatar(const std::string &uid, const std::string &avatar) { + auto channel = user_channels->choose(FLAGS_user_service);//获取通信信道 + ASSERT_TRUE(channel); + bite_im::SetUserAvatarReq req; + req.set_request_id(bite_im::uuid()); + req.set_user_id(uid); + req.set_session_id("测试登录会话ID"); + req.set_avatar(avatar); + bite_im::SetUserAvatarRsp rsp; + brpc::Controller cntl; + bite_im::UserService_Stub stub(channel.get()); + stub.SetUserAvatar(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); +} + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + user_channels = std::make_shared(); + + user_channels->declared(FLAGS_user_service); + auto put_cb = std::bind(&bite_im::ServiceManager::onServiceOnline, user_channels.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&bite_im::ServiceManager::onServiceOffline, user_channels.get(), std::placeholders::_1, std::placeholders::_2); + + //2. 构造服务发现对象 + bite_im::Discovery::ptr dclient = std::make_shared(FLAGS_etcd_host, FLAGS_base_service, put_cb, del_cb); + + reg_user("小猪佩奇", "123456"); + reg_user("小猪乔治", "123456"); + std::string uid1, uid2; + std::cout << "输入佩奇用户ID:"; + std::fflush(stdout); + std::cin >> uid1; + std::cout << "输入乔治用户ID:"; + std::fflush(stdout); + std::cin >> uid2; + set_user_avatar(uid1, "佩奇的头像数据"); + set_user_avatar(uid2, "乔治的头像数据"); + // set_user_avatar("672f-c755e83e-0000", "猪爸爸头像数据"); + // set_user_avatar("ee55-9043bfd7-0001", "猪妈妈头像数据"); + return 0; +} \ No newline at end of file diff --git a/transmite/transmite_server.conf b/transmite/transmite_server.conf new file mode 100644 index 0000000..5fec82b --- /dev/null +++ b/transmite/transmite_server.conf @@ -0,0 +1,24 @@ +-run_mode=true +-log_file=/im/logs/transmite.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-instance_name=/transmite_service/instance +-access_host=10.0.0.235:10004 +-listen_port=10004 +-rpc_timeout=-1 +-rpc_threads=1 +-base_service=/service +-user_service=/service/user_service +-mysql_host=10.0.0.235 +-mysql_user=root +-mysql_pswd=123456 +-mysql_db=bite_im +-mysql_cset=utf8 +-mysql_port=0 +-mysql_pool_count=4 +-mq_user=root +-mq_pswd=123456 +-mq_host=10.0.0.235:5672 +-mq_msg_exchange=msg_exchange +-mq_msg_queue=msg_queue +-mq_msg_binding_key=msg_queue \ No newline at end of file diff --git a/user/CMakeLists.txt b/user/CMakeLists.txt new file mode 100644 index 0000000..a421a9d --- /dev/null +++ b/user/CMakeLists.txt @@ -0,0 +1,93 @@ +# 1. 添加cmake版本说明 +cmake_minimum_required(VERSION 3.1.3) +# 2. 声明工程名称 +project(user_server) + +set(target "user_server") + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的proto映射代码文件名称 +set(proto_path ${CMAKE_CURRENT_SOURCE_DIR}/../proto) +set(proto_files base.proto user.proto file.proto) +# 2. 检测框架代码文件是否已经生成 +set(proto_hxx "") +set(proto_cxx "") +set(proto_srcs "") +foreach(proto_file ${proto_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".proto" ".pb.cc" proto_cc ${proto_file}) + string(REPLACE ".proto" ".pb.h" proto_hh ${proto_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${proto_cc}) + add_custom_command( + PRE_BUILD + COMMAND protoc + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} -I ${proto_path} --experimental_allow_proto3_optional ${proto_path}/${proto_file} + DEPENDS ${proto_path}/${proto_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + COMMENT "生成Protobuf框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc} + ) + endif() + list(APPEND proto_srcs ${CMAKE_CURRENT_BINARY_DIR}/${proto_cc}) +endforeach() + +# 3. 检测并生成ODB框架代码 +# 1. 添加所需的odb映射代码文件名称 +set(odb_path ${CMAKE_CURRENT_SOURCE_DIR}/../odb) +set(odb_files user.hxx) +# 2. 检测框架代码文件是否已经生成 +set(odb_hxx "") +set(odb_cxx "") +set(odb_srcs "") +foreach(odb_file ${odb_files}) +# 3. 如果没有生成,则预定义生成指令 -- 用于在构建项目之间先生成框架代码 + string(REPLACE ".hxx" "-odb.hxx" odb_hxx ${odb_file}) + string(REPLACE ".hxx" "-odb.cxx" odb_cxx ${odb_file}) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}${odb_cxx}) + add_custom_command( + PRE_BUILD + COMMAND odb + ARGS -d mysql --std c++11 --generate-query --generate-schema --profile boost/date-time ${odb_path}/${odb_file} + DEPENDS ${odb_path}/${odb_file} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx} + COMMENT "生成ODB框架代码文件:" ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx} + ) + endif() +# 4. 将所有生成的框架源码文件名称保存起来 student-odb.cxx classes-odb.cxx + list(APPEND odb_srcs ${CMAKE_CURRENT_BINARY_DIR}/${odb_cxx}) +endforeach() + +# 4. 获取源码目录下的所有源码文件 +set(src_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/source src_files) +# 5. 声明目标及依赖 +add_executable(${target} ${src_files} ${proto_srcs} ${odb_srcs}) +# 7. 设置需要连接的库 +target_link_libraries(${target} -lgflags + -lspdlog -lfmt -lbrpc -lssl -lcrypto + -lprotobuf -lleveldb -letcd-cpp-api + -lcpprest -lcurl -lodb-mysql -lodb -lodb-boost + /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19 + -lalibabacloud-sdk-core -lcpr -lelasticlient + -lhiredis -lredis++) + + +set(test_client "user_client") +set(test_files "") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/test test_files) +add_executable(${test_client} ${test_files} ${proto_srcs}) +target_link_libraries(${test_client} -pthread + -lgtest -lgflags -lspdlog -lfmt -lbrpc + -lssl -lcrypto -lprotobuf -lleveldb + -letcd-cpp-api -lcpprest -lcurl + /usr/lib/x86_64-linux-gnu/libjsoncpp.so.19) + +# 6. 设置头文件默认搜索路径 +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../common) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../odb) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third/include) + +#8. 设置安装路径 +INSTALL(TARGETS ${target} ${test_client} RUNTIME DESTINATION bin) \ No newline at end of file diff --git a/user/dockerfile b/user/dockerfile new file mode 100644 index 0000000..702c6f6 --- /dev/null +++ b/user/dockerfile @@ -0,0 +1,16 @@ +# 声明基础经镜像来源 +FROM debian:12 + +# 声明工作目录 +WORKDIR /im +RUN mkdir -p /im/logs &&\ + mkdir -p /im/data &&\ + mkdir -p /im/conf &&\ + mkdir -p /im/bin + +# 将可执行程序依赖,拷贝进镜像 +COPY ./build/user_server /im/bin/ +# 将可执行程序文件,拷贝进镜像 +COPY ./depends /lib/x86_64-linux-gnu/ +# 设置容器启动的默认操作 ---运行程序 +CMD /im/bin/user_server -flagfile=/im/conf/user_server.conf \ No newline at end of file diff --git a/user/source/user_server.cc b/user/source/user_server.cc new file mode 100644 index 0000000..2f5f7bf --- /dev/null +++ b/user/source/user_server.cc @@ -0,0 +1,59 @@ +//主要实现语音识别子服务的服务器的搭建 +#include "user_server.hpp" + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(registry_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(instance_name, "/user_service/instance", "当前实例名称"); +DEFINE_string(access_host, "127.0.0.1:10003", "当前实例的外部访问地址"); + +DEFINE_int32(listen_port, 10003, "Rpc服务器监听端口"); +DEFINE_int32(rpc_timeout, -1, "Rpc调用超时时间"); +DEFINE_int32(rpc_threads, 1, "Rpc的IO线程数量"); + + +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(file_service, "/service/file_service", "文件管理子服务名称"); + +DEFINE_string(es_host, "http://127.0.0.1:9200/", "ES搜索引擎服务器URL"); + +DEFINE_string(mysql_host, "127.0.0.1", "Mysql服务器访问地址"); +DEFINE_string(mysql_user, "root", "Mysql服务器访问用户名"); +DEFINE_string(mysql_pswd, "123456", "Mysql服务器访问密码"); +DEFINE_string(mysql_db, "bite_im", "Mysql默认库名称"); +DEFINE_string(mysql_cset, "utf8", "Mysql客户端字符集"); +DEFINE_int32(mysql_port, 0, "Mysql服务器访问端口"); +DEFINE_int32(mysql_pool_count, 4, "Mysql连接池最大连接数量"); + + +DEFINE_string(redis_host, "127.0.0.1", "Redis服务器访问地址"); +DEFINE_int32(redis_port, 6379, "Redis服务器访问端口"); +DEFINE_int32(redis_db, 0, "Redis默认库号"); +DEFINE_bool(redis_keep_alive, true, "Redis长连接保活选项"); + + +DEFINE_string(dms_key_id, "LTAI5t6NF7vt499UeqYX6LB9", "短信平台密钥ID"); +DEFINE_string(dms_key_secret, "5hx1qvpXHDKfQDk73aJs6j53Q8KcF2", "短信平台密钥"); + + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + bite_im::UserServerBuilder usb; + // usb.make_dms_object(FLAGS_dms_key_id, FLAGS_dms_key_secret); + usb.make_SendE_object(); + usb.make_es_object({FLAGS_es_host}); + usb.make_mysql_object(FLAGS_mysql_user, FLAGS_mysql_pswd, FLAGS_mysql_host, + FLAGS_mysql_db, FLAGS_mysql_cset, FLAGS_mysql_port, FLAGS_mysql_pool_count); + usb.make_redis_object(FLAGS_redis_host, FLAGS_redis_port, FLAGS_redis_db, FLAGS_redis_keep_alive); + usb.make_discovery_object(FLAGS_registry_host, FLAGS_base_service, FLAGS_file_service); + usb.make_rpc_server(FLAGS_listen_port, FLAGS_rpc_timeout, FLAGS_rpc_threads); + usb.make_registry_object(FLAGS_registry_host, FLAGS_base_service + FLAGS_instance_name, FLAGS_access_host); + auto server = usb.build(); + server->start(); + return 0; +} \ No newline at end of file diff --git a/user/source/user_server.hpp b/user/source/user_server.hpp new file mode 100644 index 0000000..96f3203 --- /dev/null +++ b/user/source/user_server.hpp @@ -0,0 +1,800 @@ +//实现语音识别子服务 +#include +#include + +#include "data_es.hpp" // es数据管理客户端封装 +#include "data_redis.hpp" // redis数据管理客户端封装 +#include "mysql_user.hpp" // mysql数据管理客户端封装 +#include "etcd.hpp" // 服务注册模块封装 +#include "logger.hpp" // 日志模块封装 +#include "utils.hpp" // 基础工具接口 +// #include "dms.hpp" // 短信平台SDK模块封装 +#include "sendemail.hpp" //邮件发送模块 +#include "channel.hpp" // 信道管理模块封装 + +#include "user.hxx" +#include "user-odb.hxx" + +#include "user.pb.h" // protobuf框架代码 +#include "base.pb.h" // protobuf框架代码 +#include "file.pb.h" // protobuf框架代码 + +namespace bite_im{ +class UserServiceImpl : public bite_im::UserService { + public: + UserServiceImpl(const SendEmail::ptr &SendE_client, + //const DMSClient::ptr &dms_client, + const std::shared_ptr &es_client, + const std::shared_ptr &mysql_client, + const std::shared_ptr &redis_client, + const ServiceManager::ptr &channel_manager, + const std::string &file_service_name) : + _es_user(std::make_shared(es_client)), + _mysql_user(std::make_shared(mysql_client)), + _redis_session(std::make_shared(redis_client)), + _redis_status(std::make_shared(redis_client)), + _redis_codes(std::make_shared(redis_client)), + _file_service_name(file_service_name), + _mm_channels(channel_manager), + // _dms_client(dms_client) + _SendE_client(SendE_client) + { + _es_user->createIndex(); + } + ~UserServiceImpl(){} + bool nickname_check(const std::string &nickname) { + return nickname.size() < 22; + } + bool password_check(const std::string &password) { + if (password.size() < 6 || password.size() > 15) { + LOG_ERROR("密码长度不合法:{}-{}", password, password.size()); + return false; + } + for (int i = 0; i < password.size(); i++) { + if (!((password[i] > 'a' && password[i] < 'z') || + (password[i] > 'A' && password[i] < 'Z') || + (password[i] > '0' && password[i] < '9') || + password[i] == '_' || password[i] == '-')) { + LOG_ERROR("密码字符不合法:{}", password); + return false; + } + } + return true; + } + virtual void UserRegister(::google::protobuf::RpcController* controller, + const ::bite_im::UserRegisterReq* request, + ::bite_im::UserRegisterRsp* response, + ::google::protobuf::Closure* done) { + LOG_DEBUG("收到用户注册请求!"); + brpc::ClosureGuard rpc_guard(done); + //定义一个错误处理函数,当出错的时候被调用 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //1. 从请求中取出昵称和密码 + std::string nickname = request->nickname(); + std::string password = request->password(); + //2. 检查昵称是否合法(只能包含字母,数字,连字符-,下划线_,长度限制 3~15 之间) + bool ret = nickname_check(nickname); + if (ret == false) { + LOG_ERROR("{} - 用户名长度不合法!", request->request_id()); + return err_response(request->request_id(), "用户名长度不合法!"); + } + //3. 检查密码是否合法(只能包含字母,数字,长度限制 6~15 之间) + ret = password_check(password); + if (ret == false) { + LOG_ERROR("{} - 密码格式不合法!", request->request_id()); + return err_response(request->request_id(), "密码格式不合法!"); + } + //4. 根据昵称在数据库进行判断是否昵称已存在 + auto user = _mysql_user->select_by_nickname(nickname); + if (user) { + LOG_ERROR("{} - 用户名被占用- {}!", request->request_id(), nickname); + return err_response(request->request_id(), "用户名被占用!"); + } + //5. 向数据库新增数据 + std::string uid = uuid(); + user = std::make_shared(uid, nickname, password); + ret = _mysql_user->insert(user); + if (ret == false) { + LOG_ERROR("{} - Mysql数据库新增数据失败!", request->request_id()); + return err_response(request->request_id(), "Mysql数据库新增数据失败!"); + } + //6. 向 ES 服务器中新增用户信息 + ret = _es_user->appendData(uid, "", nickname, "", ""); + if (ret == false) { + LOG_ERROR("{} - ES搜索引擎新增数据失败!", request->request_id()); + return err_response(request->request_id(), "ES搜索引擎新增数据失败!"); + } + //7. 组织响应,进行成功与否的响应即可。 + response->set_request_id(request->request_id()); + response->set_success(true); + } + virtual void UserLogin(::google::protobuf::RpcController* controller, + const ::bite_im::UserLoginReq* request, + ::bite_im::UserLoginRsp* response, + ::google::protobuf::Closure* done){ + LOG_DEBUG("收到用户登录请求!"); + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //1. 从请求中取出昵称和密码 + std::string nickname = request->nickname(); + std::string password = request->password(); + //2. 通过昵称获取用户信息,进行密码是否一致的判断 + auto user = _mysql_user->select_by_nickname(nickname); + if (!user || password != user->password()) { + LOG_ERROR("{} - 用户名或密码错误 - {}-{}!", request->request_id(), nickname, password); + return err_response(request->request_id(), "用户名或密码错误!"); + } + //3. 根据 redis 中的登录标记信息是否存在判断用户是否已经登录。 + bool ret = _redis_status->exists(user->user_id()); + if (ret == true) { + LOG_ERROR("{} - 用户已在其他地方登录 - {}!", request->request_id(), nickname); + return err_response(request->request_id(), "用户已在其他地方登录!"); + } + //4. 构造会话 ID,生成会话键值对,向 redis 中添加会话信息以及登录标记信息 + std::string ssid = uuid(); + _redis_session->append(ssid, user->user_id()); + //5. 添加用户登录信息 + _redis_status->append(user->user_id()); + //5. 组织响应,返回生成的会话 ID + response->set_request_id(request->request_id()); + response->set_login_session_id(ssid); + response->set_success(true); + } + bool phone_check(const std::string &phone) { + if (phone.size() != 11) return false; + if (phone[0] != '1') return false; + if (phone[1] < '3' || phone[1] > '9') return false; + for (int i = 2; i < 11; i++) { + if (phone[i] < '0' || phone[i] > '9') return false; + } + return true; + } + virtual void GetPhoneVerifyCode(::google::protobuf::RpcController* controller, + const ::bite_im::PhoneVerifyCodeReq* request, + ::bite_im::PhoneVerifyCodeRsp* response, + ::google::protobuf::Closure* done){ + LOG_DEBUG("收到短信验证码获取请求!"); + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + // 1. 从请求中取出手机号码 + std::string phone = request->phone_number(); + // 2. 验证手机号码格式是否正确(必须以 1 开始,第二位 3~9 之间,后边 9 个数字字符) + //由于此处使用的是邮箱,所以进行邮箱格式的验证 + // bool ret = phone_check(phone); + // if (ret == false) { + // LOG_ERROR("{} - 手机号码格式错误 - {}!", request->request_id(), phone); + // return err_response(request->request_id(), "手机号码格式错误!"); + // } + // 3. 生成 4 位随机验证码 + // std::string code_id = uuid(); + // std::string code = vcode(); + // // 4. 基于短信平台 SDK 发送验证码 + // ret = _dms_client->send(phone, code); + // if (ret == false) { + // LOG_ERROR("{} - 短信验证码发送失败 - {}!", request->request_id(), phone); + // return err_response(request->request_id(), "短信验证码发送失败!"); + // } + std::string code_id = uuid(); + bool ret = _SendE_client->SEND_email(phone); + if (ret == false) { + LOG_ERROR("{} - 邮箱验证码发送失败 - {}!", request->request_id(), phone); + return err_response(request->request_id(), "邮箱验证码发送失败!"); + } + // 5. 构造验证码 ID,添加到 redis 验证码映射键值索引中 + _redis_codes->append(code_id, _SendE_client->getVerifyCode()); + // 6. 组织响应,返回生成的验证码 ID + response->set_request_id(request->request_id()); + response->set_success(true); + response->set_verify_code_id(code_id); + LOG_DEBUG("获取邮箱验证码处理完成!"); + } + virtual void PhoneRegister(::google::protobuf::RpcController* controller, + const ::bite_im::PhoneRegisterReq* request, + ::bite_im::PhoneRegisterRsp* response, + ::google::protobuf::Closure* done){ + LOG_DEBUG("收到邮箱注册请求!"); + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + // 1. 从请求中取出手机号码和验证码,验证码ID + std::string phone = request->phone_number(); + std::string code_id = request->verify_code_id(); + std::string code = request->verify_code(); + // 2. 检查注册手机号码是否合法 + // bool ret = phone_check(phone); + // if (ret == false) { + // LOG_ERROR("{} - 手机号码格式错误 - {}!", request->request_id(), phone); + // return err_response(request->request_id(), "手机号码格式错误!"); + // } + // 3. 从 redis 数据库中进行验证码 ID-验证码一致性匹配 + auto vcode = _redis_codes->code(code_id); + if (vcode != code) { + LOG_ERROR("{} - 验证码错误 - {}-{}!", request->request_id(), code_id, code); + return err_response(request->request_id(), "验证码错误!"); + } + // 4. 通过数据库查询判断手机号是否已经注册过 + auto user = _mysql_user->select_by_phone(phone); + if (user) { + LOG_ERROR("{} - 该邮箱号已注册过用户 - {}!", request->request_id(), phone); + return err_response(request->request_id(), "该邮箱号已注册过用户!"); + } + // 5. 向数据库新增用户信息 + std::string uid = uuid(); + user = std::make_shared(uid, phone); + bool ret = _mysql_user->insert(user); + if (ret == false) { + LOG_ERROR("{} - 向数据库添加用户信息失败 - {}!", request->request_id(), phone); + return err_response(request->request_id(), "向数据库添加用户信息失败!"); + } + // 6. 向 ES 服务器中新增用户信息 + ret = _es_user->appendData(uid, phone, uid, "", ""); + if (ret == false) { + LOG_ERROR("{} - ES搜索引擎新增数据失败!", request->request_id()); + return err_response(request->request_id(), "ES搜索引擎新增数据失败!"); + } + //7. 组织响应,进行成功与否的响应即可。 + response->set_request_id(request->request_id()); + response->set_success(true); + } + virtual void PhoneLogin(::google::protobuf::RpcController* controller, + const ::bite_im::PhoneLoginReq* request, + ::bite_im::PhoneLoginRsp* response, + ::google::protobuf::Closure* done){ + LOG_DEBUG("收到手机号登录请求!"); + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + // 1. 从请求中取出手机号码和验证码 ID,以及验证码。 + std::string phone = request->phone_number(); + std::string code_id = request->verify_code_id(); + std::string code = request->verify_code(); + // 2. 检查注册手机号码是否合法 + // bool ret = phone_check(phone); + // if (ret == false) { + // LOG_ERROR("{} - 手机号码格式错误 - {}!", request->request_id(), phone); + // return err_response(request->request_id(), "手机号码格式错误!"); + // } + // 3. 根据手机号从数据数据进行用户信息查询,判断用用户是否存在 + auto user = _mysql_user->select_by_phone(phone); + if (!user) { + LOG_ERROR("{} - 该手机号未注册用户 - {}!", request->request_id(), phone); + return err_response(request->request_id(), "该手机号未注册用户!"); + } + // 4. 从 redis 数据库中进行验证码 ID-验证码一致性匹配 + auto vcode = _redis_codes->code(code_id); + if (vcode != code) { + LOG_ERROR("{} - 验证码错误 - {}-{}!", request->request_id(), code_id, code); + return err_response(request->request_id(), "验证码错误!"); + } + _redis_codes->remove(code_id); + // 5. 根据 redis 中的登录标记信息是否存在判断用户是否已经登录。 + bool ret = _redis_status->exists(user->user_id()); + if (ret == true) { + LOG_ERROR("{} - 用户已在其他地方登录 - {}!", request->request_id(), phone); + return err_response(request->request_id(), "用户已在其他地方登录!"); + } + //4. 构造会话 ID,生成会话键值对,向 redis 中添加会话信息以及登录标记信息 + std::string ssid = uuid(); + _redis_session->append(ssid, user->user_id()); + //5. 添加用户登录信息 + _redis_status->append(user->user_id()); + // 7. 组织响应,返回生成的会话 ID + response->set_request_id(request->request_id()); + response->set_login_session_id(ssid); + response->set_success(true); + } + + //从这一步开始,用户登录之后才会进行的操作 + virtual void GetUserInfo(::google::protobuf::RpcController* controller, + const ::bite_im::GetUserInfoReq* request, + ::bite_im::GetUserInfoRsp* response, + ::google::protobuf::Closure* done){ + LOG_DEBUG("收到获取单个用户信息请求!"); + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + // 1. 从请求中取出用户 ID + std::string uid = request->user_id(); + // 2. 通过用户 ID,从数据库中查询用户信息 + auto user = _mysql_user->select_by_id(uid); + if (!user) { + LOG_ERROR("{} - 未找到用户信息 - {}!", request->request_id(), uid); + return err_response(request->request_id(), "未找到用户信息!"); + } + // 3. 根据用户信息中的头像 ID,从文件服务器获取头像文件数据,组织完整用户信息 + UserInfo *user_info = response->mutable_user_info(); + user_info->set_user_id(user->user_id()); + user_info->set_nickname(user->nickname()); + user_info->set_description(user->description()); + user_info->set_phone(user->phone()); + + if (!user->avatar_id().empty()) { + //从信道管理对象中,获取到连接了文件管理子服务的channel + auto channel = _mm_channels->choose(_file_service_name); + if (!channel) { + LOG_ERROR("{} - 未找到文件管理子服务节点 - {} - {}!", + request->request_id(), _file_service_name, uid); + return err_response(request->request_id(), "未找到文件管理子服务节点!"); + } + //进行文件子服务的rpc请求,进行头像文件下载 + bite_im::FileService_Stub stub(channel.get()); + bite_im::GetSingleFileReq req; + bite_im::GetSingleFileRsp rsp; + req.set_request_id(request->request_id()); + req.set_file_id(user->avatar_id()); + brpc::Controller cntl; + stub.GetSingleFile(&cntl, &req, &rsp, nullptr); + if (cntl.Failed() == true || rsp.success() == false) { + LOG_ERROR("{} - 文件子服务调用失败:{}!", request->request_id(), cntl.ErrorText()); + return err_response(request->request_id(), "文件子服务调用失败!"); + } + user_info->set_avatar(rsp.file_data().file_content()); + } + // 4. 组织响应,返回用户信息 + response->set_request_id(request->request_id()); + response->set_success(true); + } + + virtual void GetMultiUserInfo(::google::protobuf::RpcController* controller, + const ::bite_im::GetMultiUserInfoReq* request, + ::bite_im::GetMultiUserInfoRsp* response, + ::google::protobuf::Closure* done){ + LOG_DEBUG("收到批量用户信息获取请求!"); + brpc::ClosureGuard rpc_guard(done); + //1. 定义错误回调 + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + //2. 从请求中取出用户ID --- 列表 + std::vector uid_lists; + for (int i = 0; i < request->users_id_size(); i++) { + uid_lists.push_back(request->users_id(i)); + } + //3. 从数据库进行批量用户信息查询 + auto users = _mysql_user->select_multi_users(uid_lists); + if (users.size() != request->users_id_size()) { + LOG_ERROR("{} - 从数据库查找的用户信息数量不一致 {}-{}!", + request->request_id(), request->users_id_size(), users.size()); + return err_response(request->request_id(), "从数据库查找的用户信息数量不一致!"); + } + //4. 批量从文件管理子服务进行文件下载 + auto channel = _mm_channels->choose(_file_service_name); + if (!channel) { + LOG_ERROR("{} - 未找到文件管理子服务节点 - {}!", request->request_id(), _file_service_name); + return err_response(request->request_id(), "未找到文件管理子服务节点!"); + } + bite_im::FileService_Stub stub(channel.get()); + bite_im::GetMultiFileReq req; + bite_im::GetMultiFileRsp rsp; + req.set_request_id(request->request_id()); + for (auto &user : users) { + if (user.avatar_id().empty()) continue; + req.add_file_id_list(user.avatar_id()); + } + brpc::Controller cntl; + stub.GetMultiFile(&cntl, &req, &rsp, nullptr); + if (cntl.Failed() == true || rsp.success() == false) { + LOG_ERROR("{} - 文件子服务调用失败:{} - {}!", request->request_id(), + _file_service_name, cntl.ErrorText()); + return err_response(request->request_id(), "文件子服务调用失败!"); + } + //5. 组织响应() + for (auto &user : users) { + auto user_map = response->mutable_users_info();//本次请求要响应的用户信息map + auto file_map = rsp.mutable_file_data(); //这是批量文件请求响应中的map + UserInfo user_info; + user_info.set_user_id(user.user_id()); + user_info.set_nickname(user.nickname()); + user_info.set_description(user.description()); + user_info.set_phone(user.phone()); + user_info.set_avatar((*file_map)[user.avatar_id()].file_content()); + (*user_map)[user_info.user_id()] = user_info; + } + response->set_request_id(request->request_id()); + response->set_success(true); + } + + virtual void SetUserAvatar(::google::protobuf::RpcController* controller, + const ::bite_im::SetUserAvatarReq* request, + ::bite_im::SetUserAvatarRsp* response, + ::google::protobuf::Closure* done) + { + LOG_DEBUG("收到用户头像设置请求!"); + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + // 1. 从请求中取出用户 ID 与头像数据 + std::string uid = request->user_id(); + // 2. 从数据库通过用户 ID 进行用户信息查询,判断用户是否存在 + auto user = _mysql_user->select_by_id(uid); + if (!user) { + LOG_ERROR("{} - 未找到用户信息 - {}!", request->request_id(), uid); + return err_response(request->request_id(), "未找到用户信息!"); + } + // 3. 上传头像文件到文件子服务, + auto channel = _mm_channels->choose(_file_service_name); + if (!channel) { + LOG_ERROR("{} - 未找到文件管理子服务节点 - {}!", request->request_id(), _file_service_name); + return err_response(request->request_id(), "未找到文件管理子服务节点!"); + } + bite_im::FileService_Stub stub(channel.get()); + bite_im::PutSingleFileReq req; + bite_im::PutSingleFileRsp rsp; + req.set_request_id(request->request_id()); + req.mutable_file_data()->set_file_name(""); + req.mutable_file_data()->set_file_size(request->avatar().size()); + req.mutable_file_data()->set_file_content(request->avatar()); + brpc::Controller cntl; + stub.PutSingleFile(&cntl, &req, &rsp, nullptr); + if (cntl.Failed() == true || rsp.success() == false) { + LOG_ERROR("{} - 文件子服务调用失败:{}!", request->request_id(), cntl.ErrorText()); + return err_response(request->request_id(), "文件子服务调用失败!"); + } + std::string avatar_id = rsp.file_info().file_id(); + // 4. 将返回的头像文件 ID 更新到数据库中 + user->avatar_id(avatar_id); + bool ret = _mysql_user->update(user); + if (ret == false) { + LOG_ERROR("{} - 更新数据库用户头像ID失败 :{}!", request->request_id(), avatar_id); + return err_response(request->request_id(), "更新数据库用户头像ID失败!"); + } + // 5. 更新 ES 服务器中用户信息 + ret = _es_user->appendData(user->user_id(), user->phone(), + user->nickname(), user->description(), user->avatar_id()); + if (ret == false) { + LOG_ERROR("{} - 更新搜索引擎用户头像ID失败 :{}!", request->request_id(), avatar_id); + return err_response(request->request_id(), "更新搜索引擎用户头像ID失败!"); + } + // 6. 组织响应,返回更新成功与否 + response->set_request_id(request->request_id()); + response->set_success(true); + } + + virtual void SetUserNickname(::google::protobuf::RpcController* controller, + const ::bite_im::SetUserNicknameReq* request, + ::bite_im::SetUserNicknameRsp* response, + ::google::protobuf::Closure* done){ + LOG_DEBUG("收到用户昵称设置请求!"); + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + // 1. 从请求中取出用户 ID 与新的昵称 + std::string uid = request->user_id(); + std::string new_nickname = request->nickname(); + // 2. 判断昵称格式是否正确 + bool ret = nickname_check(new_nickname); + if (ret == false) { + LOG_ERROR("{} - 用户名长度不合法!", request->request_id()); + return err_response(request->request_id(), "用户名长度不合法!"); + } + // 3. 从数据库通过用户 ID 进行用户信息查询,判断用户是否存在 + auto user = _mysql_user->select_by_id(uid); + if (!user) { + LOG_ERROR("{} - 未找到用户信息 - {}!", request->request_id(), uid); + return err_response(request->request_id(), "未找到用户信息!"); + } + // 4. 将新的昵称更新到数据库中 + user->nickname(new_nickname); + ret = _mysql_user->update(user); + if (ret == false) { + LOG_ERROR("{} - 更新数据库用户昵称失败 :{}!", request->request_id(), new_nickname); + return err_response(request->request_id(), "更新数据库用户昵称失败!"); + } + // 5. 更新 ES 服务器中用户信息 + ret = _es_user->appendData(user->user_id(), user->phone(), + user->nickname(), user->description(), user->avatar_id()); + if (ret == false) { + LOG_ERROR("{} - 更新搜索引擎用户昵称失败 :{}!", request->request_id(), new_nickname); + return err_response(request->request_id(), "更新搜索引擎用户昵称失败!"); + } + // 6. 组织响应,返回更新成功与否 + response->set_request_id(request->request_id()); + response->set_success(true); + } + + virtual void SetUserDescription(::google::protobuf::RpcController* controller, + const ::bite_im::SetUserDescriptionReq* request, + ::bite_im::SetUserDescriptionRsp* response, + ::google::protobuf::Closure* done){ + LOG_DEBUG("收到用户签名设置请求!"); + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + // 1. 从请求中取出用户 ID 与新的昵称 + std::string uid = request->user_id(); + std::string new_description = request->description(); + // 3. 从数据库通过用户 ID 进行用户信息查询,判断用户是否存在 + auto user = _mysql_user->select_by_id(uid); + if (!user) { + LOG_ERROR("{} - 未找到用户信息 - {}!", request->request_id(), uid); + return err_response(request->request_id(), "未找到用户信息!"); + } + // 4. 将新的昵称更新到数据库中 + user->description(new_description); + bool ret = _mysql_user->update(user); + if (ret == false) { + LOG_ERROR("{} - 更新数据库用户签名失败 :{}!", request->request_id(), new_description); + return err_response(request->request_id(), "更新数据库用户签名失败!"); + } + // 5. 更新 ES 服务器中用户信息 + ret = _es_user->appendData(user->user_id(), user->phone(), + user->nickname(), user->description(), user->avatar_id()); + if (ret == false) { + LOG_ERROR("{} - 更新搜索引擎用户签名失败 :{}!", request->request_id(), new_description); + return err_response(request->request_id(), "更新搜索引擎用户签名失败!"); + } + // 6. 组织响应,返回更新成功与否 + response->set_request_id(request->request_id()); + response->set_success(true); + } + + virtual void SetUserPhoneNumber(::google::protobuf::RpcController* controller, + const ::bite_im::SetUserPhoneNumberReq* request, + ::bite_im::SetUserPhoneNumberRsp* response, + ::google::protobuf::Closure* done){ + LOG_DEBUG("收到用户邮箱号设置请求!"); + brpc::ClosureGuard rpc_guard(done); + auto err_response = [this, response](const std::string &rid, + const std::string &errmsg) -> void { + response->set_request_id(rid); + response->set_success(false); + response->set_errmsg(errmsg); + return; + }; + // 1. 从请求中取出用户 ID 与新的昵称 + std::string uid = request->user_id(); + std::string new_phone = request->phone_number(); + std::string code = request->phone_verify_code(); + std::string code_id = request->phone_verify_code_id(); + // 2. 对验证码进行验证 + auto vcode = _redis_codes->code(code_id); + if (vcode != code) { + LOG_ERROR("{} - 验证码错误 - {}-{}!", request->request_id(), code_id, code); + return err_response(request->request_id(), "验证码错误!"); + } + // 3. 从数据库通过用户 ID 进行用户信息查询,判断用户是否存在 + auto user = _mysql_user->select_by_id(uid); + if (!user) { + LOG_ERROR("{} - 未找到用户信息 - {}!", request->request_id(), uid); + return err_response(request->request_id(), "未找到用户信息!"); + } + // 4. 将新的昵称更新到数据库中 + user->phone(new_phone); + bool ret = _mysql_user->update(user); + if (ret == false) { + LOG_ERROR("{} - 更新数据库用户邮箱号失败 :{}!", request->request_id(), new_phone); + return err_response(request->request_id(), "更新数据库用户邮箱号失败!"); + } + // 5. 更新 ES 服务器中用户信息 + ret = _es_user->appendData(user->user_id(), user->phone(), + user->nickname(), user->description(), user->avatar_id()); + if (ret == false) { + LOG_ERROR("{} - 更新搜索引擎用户邮箱号失败 :{}!", request->request_id(), new_phone); + return err_response(request->request_id(), "更新搜索引擎用户邮箱号失败!"); + } + // 6. 组织响应,返回更新成功与否 + response->set_request_id(request->request_id()); + response->set_success(true); + } + private: + ESUser::ptr _es_user; + UserTable::ptr _mysql_user; + Session::ptr _redis_session; + Status::ptr _redis_status; + Codes::ptr _redis_codes; + //这边是rpc调用客户端相关对象 + std::string _file_service_name; + ServiceManager::ptr _mm_channels; + // DMSClient::ptr _dms_client; + SendEmail::ptr _SendE_client; +}; + +class UserServer { + public: + using ptr = std::shared_ptr; + UserServer(const Discovery::ptr service_discoverer, + const Registry::ptr ®_client, + const std::shared_ptr &es_client, + const std::shared_ptr &mysql_client, + std::shared_ptr &redis_client, + const std::shared_ptr &server): + _service_discoverer(service_discoverer), + _registry_client(reg_client), + _es_client(es_client), + _mysql_client(mysql_client), + _redis_client(redis_client), + _rpc_server(server){} + ~UserServer(){} + //搭建RPC服务器,并启动服务器 + void start() { + _rpc_server->RunUntilAskedToQuit(); + } + private: + Discovery::ptr _service_discoverer; + Registry::ptr _registry_client; + std::shared_ptr _es_client; + std::shared_ptr _mysql_client; + std::shared_ptr _redis_client; + std::shared_ptr _rpc_server; +}; + +class UserServerBuilder { + public: + //构造es客户端对象 + void make_es_object(const std::vector host_list) { + _es_client = ESClientFactory::create(host_list); + } + // void make_dms_object(const std::string &access_key_id, + // const std::string &access_key_secret) { + // _dms_client = std::make_shared(access_key_id, access_key_secret); + // } + void make_SendE_object() { + _SendE_client = std::make_shared(); + } + + //构造mysql客户端对象 + void make_mysql_object( + const std::string &user, + const std::string &pswd, + const std::string &host, + const std::string &db, + const std::string &cset, + int port, + int conn_pool_count) { + _mysql_client = ODBFactory::create(user, pswd, host, db, cset, port, conn_pool_count); + } + //构造redis客户端对象 + void make_redis_object(const std::string &host, + int port, + int db, + bool keep_alive) { + _redis_client = RedisClientFactory::create(host, port, db, keep_alive); + } + //用于构造服务发现客户端&信道管理对象 + void make_discovery_object(const std::string ®_host, + const std::string &base_service_name, + const std::string &file_service_name) { + _file_service_name = file_service_name; + _mm_channels = std::make_shared(); + _mm_channels->declared(file_service_name); + LOG_DEBUG("设置文件子服务为需添加管理的子服务:{}", file_service_name); + auto put_cb = std::bind(&ServiceManager::onServiceOnline, _mm_channels.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&ServiceManager::onServiceOffline, _mm_channels.get(), std::placeholders::_1, std::placeholders::_2); + _service_discoverer = std::make_shared(reg_host, base_service_name, put_cb, del_cb); + } + //用于构造服务注册客户端对象 + void make_registry_object(const std::string ®_host, + const std::string &service_name, + const std::string &access_host) { + _registry_client = std::make_shared(reg_host); + _registry_client->registry(service_name, access_host); + } + void make_rpc_server(uint16_t port, int32_t timeout, uint8_t num_threads) { + if (!_es_client) { + LOG_ERROR("还未初始化ES搜索引擎模块!"); + abort(); + } + if (!_mysql_client) { + LOG_ERROR("还未初始化Mysql数据库模块!"); + abort(); + } + if (!_redis_client) { + LOG_ERROR("还未初始化Redis数据库模块!"); + abort(); + } + if (!_mm_channels) { + LOG_ERROR("还未初始化信道管理模块!"); + abort(); + } + if (!_SendE_client) { //_dms_client + // LOG_ERROR("还未初始化短信平台模块!"); + LOG_ERROR("还未初始化发送邮件模块!"); + abort(); + } + _rpc_server = std::make_shared(); + + UserServiceImpl *user_service = new UserServiceImpl(_SendE_client, _es_client, + _mysql_client, _redis_client, _mm_channels, _file_service_name); + int ret = _rpc_server->AddService(user_service, + brpc::ServiceOwnership::SERVER_OWNS_SERVICE); + if (ret == -1) { + LOG_ERROR("添加Rpc服务失败!"); + abort(); + } + brpc::ServerOptions options; + options.idle_timeout_sec = timeout; + options.num_threads = num_threads; + ret = _rpc_server->Start(port, &options); + if (ret == -1) { + LOG_ERROR("服务启动失败!"); + abort(); + } + } + //构造RPC服务器对象 + UserServer::ptr build() { + if (!_service_discoverer) { + LOG_ERROR("还未初始化服务发现模块!"); + abort(); + } + if (!_registry_client) { + LOG_ERROR("还未初始化服务注册模块!"); + abort(); + } + if (!_rpc_server) { + LOG_ERROR("还未初始化RPC服务器模块!"); + abort(); + } + UserServer::ptr server = std::make_shared( + _service_discoverer, _registry_client, + _es_client, _mysql_client, _redis_client, _rpc_server); + return server; + } + private: + Registry::ptr _registry_client; + + std::shared_ptr _es_client; + std::shared_ptr _mysql_client; + std::shared_ptr _redis_client; + + std::string _file_service_name; + ServiceManager::ptr _mm_channels; + Discovery::ptr _service_discoverer; + + // std::shared_ptr _dms_client; + std::shared_ptr _SendE_client; + + std::shared_ptr _rpc_server; +}; +} \ No newline at end of file diff --git a/user/test/es_test/main.cc b/user/test/es_test/main.cc new file mode 100644 index 0000000..8248ab6 --- /dev/null +++ b/user/test/es_test/main.cc @@ -0,0 +1,32 @@ +#include "../../../common/data_es.hpp" +#include + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + + +DEFINE_string(es_host, "http://127.0.0.1:9200/", "es服务器URL"); + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + auto es_client = bite_im::ESClientFactory::create({FLAGS_es_host}); + + auto es_user = std::make_shared(es_client); + es_user->createIndex(); + // es_user->appendData("用户ID1", "手机号1", "小猪佩奇", "这是一只小猪", "小猪头像1"); + // es_user->appendData("用户ID2", "手机号2", "小猪乔治", "这是一只小小猪", "小猪头像2"); + auto res = es_user->search("小猪", {"用户ID1"}); + for (auto &u : res) { + std::cout << "-----------------" << std::endl; + std::cout << u.user_id() << std::endl; + std::cout << *u.phone() << std::endl; + std::cout << *u.nickname() << std::endl; + std::cout << *u.description() << std::endl; + std::cout << *u.avatar_id() << std::endl; + } + return 0; +} \ No newline at end of file diff --git a/user/test/mysql_test/main.cc b/user/test/mysql_test/main.cc new file mode 100644 index 0000000..749caf3 --- /dev/null +++ b/user/test/mysql_test/main.cc @@ -0,0 +1,58 @@ +#include "../../../common/data_mysql.hpp" +#include "../../../odb/user.hxx" +#include "user-odb.hxx" +#include + + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + + +void insert(bite_im::UserTable &user) { + auto user1 = std::make_shared("uid1", "昵称1", "123456"); + user.insert(user1); + + auto user2 = std::make_shared("uid2", "15566667777"); + user.insert(user2); +} + +void update_by_id(bite_im::UserTable &user_tb) { + auto user = user_tb.select_by_id("uid1"); + user->description("我是一个风一样的男子!!"); + user_tb.update(user); +} +void update_by_phone(bite_im::UserTable &user_tb) { + auto user = user_tb.select_by_phone("15566667777"); + user->password("22223333"); + user_tb.update(user); +} +void update_by_nickname(bite_im::UserTable &user_tb) { + auto user = user_tb.select_by_nickname("uid2"); + user->nickname("昵称2"); + user_tb.update(user); +} +void select_users(bite_im::UserTable &user_tb) { + std::vector id_list = {"uid1", "uid2"}; + auto res = user_tb.select_multi_users(id_list); + for (auto user : res) { + std::cout << user.nickname() << std::endl; + } +} + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + auto db = bite_im::ODBFactory::create("root", "123456", "127.0.0.1", "bite_im", "utf8", 0, 1); + + bite_im::UserTable user(db); + + //insert(user); + //update_by_id(user); + //update_by_phone(user); + //update_by_nickname(user); + select_users(user); + return 0; +} \ No newline at end of file diff --git a/user/test/redis_test/main.cc b/user/test/redis_test/main.cc new file mode 100644 index 0000000..1fff8fc --- /dev/null +++ b/user/test/redis_test/main.cc @@ -0,0 +1,84 @@ +#include "../../../common/data_redis.hpp" +#include +#include + + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + + +DEFINE_string(ip, "127.0.0.1", "这是服务器的IP地址,格式:127.0.0.1"); +DEFINE_int32(port, 6379, "这是服务器的端口, 格式: 8080"); +DEFINE_int32(db, 0, "库的编号:默认0号"); +DEFINE_bool(keep_alive, true, "是否进行长连接保活"); + +void session_test(const std::shared_ptr &client) { + bite_im::Session ss(client); + ss.append("会话ID1", "用户ID1"); + ss.append("会话ID2", "用户ID2"); + ss.append("会话ID3", "用户ID3"); + ss.append("会话ID4", "用户ID4"); + + ss.remove("会话ID2"); + ss.remove("会话ID3"); + + auto res1 = ss.uid("会话ID1"); + if (res1) std::cout << *res1 << std::endl; + auto res2 = ss.uid("会话ID2"); + if (res2) std::cout << *res2 << std::endl; + auto res3 = ss.uid("会话ID3"); + if (res3) std::cout << *res3 << std::endl; + auto res4 = ss.uid("会话ID4"); + if (res4) std::cout << *res4 << std::endl; +} + +void status_test(const std::shared_ptr &client) { + bite_im::Status status(client); + status.append("用户ID1"); + status.append("用户ID2"); + status.append("用户ID3"); + + status.remove("用户ID2"); + + if (status.exists("用户ID1")) std::cout << "用户1在线!" << std::endl; + if (status.exists("用户ID2")) std::cout << "用户2在线!" << std::endl; + if (status.exists("用户ID3")) std::cout << "用户3在线!" << std::endl; +} + +void code_test(const std::shared_ptr &client) { + bite_im::Codes codes(client); + codes.append("验证码ID1", "验证码1"); + codes.append("验证码ID2", "验证码2"); + codes.append("验证码ID3", "验证码3"); + + codes.remove("验证码ID2"); + + auto y1 = codes.code("验证码ID1"); + auto y2 = codes.code("验证码ID2"); + auto y3 = codes.code("验证码ID3"); + if (y1) std::cout << *y1 << std::endl; + if (y2) std::cout << *y2 << std::endl; + if (y3) std::cout << *y3 << std::endl; + + std::this_thread::sleep_for(std::chrono::seconds(4)); + auto y4 = codes.code("验证码ID1"); + auto y5 = codes.code("验证码ID2"); + auto y6 = codes.code("验证码ID3"); + if (!y4) std::cout << "验证码ID1不存在" << std::endl; + if (!y5) std::cout << "验证码ID2不存在" << std::endl; + if (!y6) std::cout << "验证码ID3不存在" << std::endl; +} + +int main(int argc, char *argv[]) +{ + google::ParseCommandLineFlags(&argc, &argv, true); + //bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + auto client = bite_im::RedisClientFactory::create(FLAGS_ip, FLAGS_port, FLAGS_db, FLAGS_keep_alive); + + //session_test(client); + //status_test(client); + code_test(client); + return 0; +} diff --git a/user/test/user_client.cc b/user/test/user_client.cc new file mode 100644 index 0000000..0efa26e --- /dev/null +++ b/user/test/user_client.cc @@ -0,0 +1,308 @@ +#include "etcd.hpp" +#include "channel.hpp" +#include "utils.hpp" +#include +#include +#include +#include "user.pb.h" +#include "base.pb.h" + +DEFINE_bool(run_mode, false, "程序的运行模式,false-调试; true-发布;"); +DEFINE_string(log_file, "", "发布模式下,用于指定日志的输出文件"); +DEFINE_int32(log_level, 0, "发布模式下,用于指定日志输出等级"); + +DEFINE_string(etcd_host, "http://127.0.0.1:2379", "服务注册中心地址"); +DEFINE_string(base_service, "/service", "服务监控根目录"); +DEFINE_string(user_service, "/service/user_service", "服务监控根目录"); + +bite_im::ServiceManager::ptr _user_channels; + +bite_im::UserInfo user_info; + +std::string login_ssid; +std::string new_nickname = "亲爱的猪妈妈"; + +//测试已通过 +// TEST(用户子服务测试, 用户注册测试) { +// auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 +// ASSERT_TRUE(channel); +// // user_info.set_nickname("猪爸爸"); + +// bite_im::UserRegisterReq req; +// // req.set_request_id(bite_im::uuid()); +// req.set_request_id(user_info.user_id()); +// req.set_nickname(user_info.nickname()); +// req.set_password("123456"); +// bite_im::UserRegisterRsp rsp; +// brpc::Controller cntl; +// bite_im::UserService_Stub stub(channel.get()); +// stub.UserRegister(&cntl, &req, &rsp, nullptr); +// ASSERT_FALSE(cntl.Failed()); +// ASSERT_TRUE(rsp.success()); +// } + +//测试已通过 +// TEST(用户子服务测试, 用户登录测试) { +// auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 +// ASSERT_TRUE(channel); + +// bite_im::UserLoginReq req; +// req.set_request_id(bite_im::uuid()); +// req.set_nickname("猪妈妈"); +// req.set_password("123456"); +// bite_im::UserLoginRsp rsp; +// brpc::Controller cntl; +// bite_im::UserService_Stub stub(channel.get()); +// stub.UserLogin(&cntl, &req, &rsp, nullptr); +// ASSERT_FALSE(cntl.Failed()); +// ASSERT_TRUE(rsp.success()); +// login_ssid = rsp.login_session_id(); +// } + +//测试已通过 +// TEST(用户子服务测试, 用户头像设置测试) { +// auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 +// ASSERT_TRUE(channel); + +// bite_im::SetUserAvatarReq req; +// req.set_request_id(bite_im::uuid()); +// req.set_user_id(user_info.user_id()); +// req.set_session_id(login_ssid); +// req.set_avatar(user_info.avatar()); +// bite_im::SetUserAvatarRsp rsp; +// brpc::Controller cntl; +// bite_im::UserService_Stub stub(channel.get()); +// stub.SetUserAvatar(&cntl, &req, &rsp, nullptr); +// ASSERT_FALSE(cntl.Failed()); +// ASSERT_TRUE(rsp.success()); +// } + +//测试已通过 +// TEST(用户子服务测试, 用户签名设置测试) { +// auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 +// ASSERT_TRUE(channel); + +// bite_im::SetUserDescriptionReq req; +// req.set_request_id(bite_im::uuid()); +// req.set_user_id(user_info.user_id()); +// req.set_session_id(login_ssid); +// req.set_description(user_info.description()); +// bite_im::SetUserDescriptionRsp rsp; +// brpc::Controller cntl; +// bite_im::UserService_Stub stub(channel.get()); +// stub.SetUserDescription(&cntl, &req, &rsp, nullptr); +// ASSERT_FALSE(cntl.Failed()); +// ASSERT_TRUE(rsp.success()); +// } + +//测试已通过 +// TEST(用户子服务测试, 用户昵称设置测试) { +// auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 +// ASSERT_TRUE(channel); + +// bite_im::SetUserNicknameReq req; +// req.set_request_id(bite_im::uuid()); +// req.set_user_id(user_info.user_id()); +// req.set_session_id(login_ssid); +// req.set_nickname(new_nickname); +// bite_im::SetUserNicknameRsp rsp; +// brpc::Controller cntl; +// bite_im::UserService_Stub stub(channel.get()); +// stub.SetUserNickname(&cntl, &req, &rsp, nullptr); +// ASSERT_FALSE(cntl.Failed()); +// ASSERT_TRUE(rsp.success()); +// } + +//测试已通过 +// TEST(用户子服务测试, 用户信息获取测试) { +// auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 +// ASSERT_TRUE(channel); + +// bite_im::GetUserInfoReq req; +// req.set_request_id(bite_im::uuid()); +// req.set_user_id(user_info.user_id()); +// req.set_session_id(login_ssid); +// bite_im::GetUserInfoRsp rsp; +// brpc::Controller cntl; +// bite_im::UserService_Stub stub(channel.get()); +// stub.GetUserInfo(&cntl, &req, &rsp, nullptr); +// ASSERT_FALSE(cntl.Failed()); +// ASSERT_TRUE(rsp.success()); +// ASSERT_EQ(user_info.user_id(), rsp.user_info().user_id()); +// ASSERT_EQ(new_nickname, rsp.user_info().nickname()); +// ASSERT_EQ(user_info.description(), rsp.user_info().description()); +// ASSERT_EQ("", rsp.user_info().phone()); +// ASSERT_EQ(user_info.avatar(), rsp.user_info().avatar()); +// } + +//测试已通过 +// void set_user_avatar(const std::string &uid, const std::string &avatar) { +// auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 +// ASSERT_TRUE(channel); +// bite_im::SetUserAvatarReq req; +// req.set_request_id(bite_im::uuid()); +// req.set_user_id(uid); +// req.set_session_id(login_ssid); +// req.set_avatar(avatar); +// bite_im::SetUserAvatarRsp rsp; +// brpc::Controller cntl; +// bite_im::UserService_Stub stub(channel.get()); +// stub.SetUserAvatar(&cntl, &req, &rsp, nullptr); +// ASSERT_FALSE(cntl.Failed()); +// ASSERT_TRUE(rsp.success()); +// } + +// TEST(用户子服务测试, 批量用户信息获取测试) { +// set_user_avatar("用户ID1", "小猪佩奇的头像数据"); +// set_user_avatar("用户ID2", "小猪乔治的头像数据"); +// auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 +// ASSERT_TRUE(channel); + +// bite_im::GetMultiUserInfoReq req; +// req.set_request_id(bite_im::uuid()); +// req.add_users_id("用户ID1"); +// req.add_users_id("用户ID2"); +// req.add_users_id("ee55-9043bfd7-0001"); +// bite_im::GetMultiUserInfoRsp rsp; +// brpc::Controller cntl; +// bite_im::UserService_Stub stub(channel.get()); +// stub.GetMultiUserInfo(&cntl, &req, &rsp, nullptr); +// ASSERT_FALSE(cntl.Failed()); +// ASSERT_TRUE(rsp.success()); +// auto users_map = rsp.mutable_users_info(); +// bite_im::UserInfo fuser = (*users_map)["ee55-9043bfd7-0001"]; +// ASSERT_EQ(fuser.user_id(), "ee55-9043bfd7-0001"); +// ASSERT_EQ(fuser.nickname(), "猪爸爸"); +// ASSERT_EQ(fuser.description(), "这是第一个用户的描述信息"); +// ASSERT_EQ(fuser.phone(), "13800138003"); +// ASSERT_EQ(fuser.avatar(), ""); + +// bite_im::UserInfo puser = (*users_map)["用户ID1"]; +// ASSERT_EQ(puser.user_id(), "用户ID1"); +// ASSERT_EQ(puser.nickname(), "user_nickname_1"); +// ASSERT_EQ(puser.description(), "这是第一个用户的描述信息"); +// ASSERT_EQ(puser.phone(), "13800138001"); +// ASSERT_EQ(puser.avatar(), "小猪佩奇的头像数据"); + +// bite_im::UserInfo quser = (*users_map)["用户ID2"]; +// ASSERT_EQ(quser.user_id(), "用户ID2"); +// ASSERT_EQ(quser.nickname(), "user_nickname_2"); +// ASSERT_EQ(quser.description(), "这是第二个用户的描述信息"); +// ASSERT_EQ(quser.phone(), "13800138002"); +// ASSERT_EQ(quser.avatar(), "小猪乔治的头像数据"); +// } + +std::string code_id; +void get_code() { + auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 + ASSERT_TRUE(channel); + + bite_im::PhoneVerifyCodeReq req; + req.set_request_id(bite_im::uuid()); + req.set_phone_number(user_info.phone()); + bite_im::PhoneVerifyCodeRsp rsp; + brpc::Controller cntl; + bite_im::UserService_Stub stub(channel.get()); + stub.GetPhoneVerifyCode(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + code_id = rsp.verify_code_id(); +} + + +TEST(用户子服务测试, 手机号注册) { + get_code(); + auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 + ASSERT_TRUE(channel); + + bite_im::PhoneRegisterReq req; + req.set_request_id(bite_im::uuid()); + req.set_phone_number(user_info.phone()); + req.set_verify_code_id(code_id); + std::cout << "手机号注册,输入验证码:" << std::endl; + std::string code; + std::cin >> code; + req.set_verify_code(code); + bite_im::PhoneRegisterRsp rsp; + brpc::Controller cntl; + bite_im::UserService_Stub stub(channel.get()); + stub.PhoneRegister(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); +} + + +TEST(用户子服务测试, 手机号登录) { + std::this_thread::sleep_for(std::chrono::seconds(3)); + get_code(); + auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 + ASSERT_TRUE(channel); + + bite_im::PhoneLoginReq req; + req.set_request_id(bite_im::uuid()); + req.set_phone_number(user_info.phone()); + req.set_verify_code_id(code_id); + std::cout << "手机号登录,输入验证码:" << std::endl; + std::string code; + std::cin >> code; + req.set_verify_code(code); + bite_im::PhoneLoginRsp rsp; + brpc::Controller cntl; + bite_im::UserService_Stub stub(channel.get()); + stub.PhoneLogin(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); + std::cout << "手机登录会话ID:" << rsp.login_session_id() << std::endl; +} + + +TEST(用户子服务测试, 手机号设置) { + std::this_thread::sleep_for(std::chrono::seconds(10)); + get_code(); + auto channel = _user_channels->choose(FLAGS_user_service);//获取通信信道 + ASSERT_TRUE(channel); + + bite_im::SetUserPhoneNumberReq req; + req.set_request_id(bite_im::uuid()); + std::cout << "手机号设置时,输入用户ID:" << std::endl; + std::string user_id; + std::cin >> user_id; + req.set_user_id(user_id); + req.set_phone_number("2050965275@qq.com"); + req.set_phone_verify_code_id(code_id); + std::cout << "手机号设置时,输入验证码:" << std::endl; + std::string code; + std::cin >> code; + req.set_phone_verify_code(code); + bite_im::SetUserPhoneNumberRsp rsp; + brpc::Controller cntl; + bite_im::UserService_Stub stub(channel.get()); + stub.SetUserPhoneNumber(&cntl, &req, &rsp, nullptr); + ASSERT_FALSE(cntl.Failed()); + ASSERT_TRUE(rsp.success()); +} + +int main(int argc, char *argv[]) +{ + testing::InitGoogleTest(&argc, argv); + google::ParseCommandLineFlags(&argc, &argv, true); + bite_im::init_logger(FLAGS_run_mode, FLAGS_log_file, FLAGS_log_level); + + //1. 先构造Rpc信道管理对象 + _user_channels = std::make_shared(); + _user_channels->declared(FLAGS_user_service); + auto put_cb = std::bind(&bite_im::ServiceManager::onServiceOnline, _user_channels.get(), std::placeholders::_1, std::placeholders::_2); + auto del_cb = std::bind(&bite_im::ServiceManager::onServiceOffline, _user_channels.get(), std::placeholders::_1, std::placeholders::_2); + + //2. 构造服务发现对象 + bite_im::Discovery::ptr dclient = std::make_shared(FLAGS_etcd_host, FLAGS_base_service, put_cb, del_cb); + + user_info.set_nickname("猪妈妈"); + user_info.set_user_id("672f-c755e83e-0000"); + user_info.set_description("这是一个美丽的猪妈妈"); + user_info.set_phone("2050965275@qq.com"); + user_info.set_avatar("猪妈妈头像数据"); + testing::InitGoogleTest(&argc, argv); + LOG_DEBUG("开始测试!"); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/user/user_server.conf b/user/user_server.conf new file mode 100644 index 0000000..8d6816a --- /dev/null +++ b/user/user_server.conf @@ -0,0 +1,25 @@ +-run_mode=true +-log_file=/im/logs/user.log +-log_level=0 +-registry_host=http://10.0.0.235:2379 +-instance_name=/user_service/instance +-access_host=10.0.0.235:10003 +-listen_port=10003 +-rpc_timeout=-1 +-rpc_threads=1 +-base_service=/service +-file_service=/service/file_service +-es_host=http://10.0.0.235:9200/ +-mysql_host=10.0.0.235 +-mysql_user=root +-mysql_pswd=123456 +-mysql_db=bite_im +-mysql_cset=utf8 +-mysql_port=0 +-mysql_pool_count=4 +-redis_host=10.0.0.235 +-redis_port=6379 +-redis_db=0 +-redis_keep_alive=true +-dms_key_id=LTAI5t6NF7vt499UeqYX6LB9 +-dms_key_secret=5hx1qvpXHDKfQDk73aJs6j53Q8KcF2 \ No newline at end of file