基于条件变量的信号量简单实现
// Counting semaphore built on std::mutex + std::condition_variable.
// max_count_ is the nominal capacity and current_count_ the number of permits
// available right now. release() does not clamp current_count_ to max_count_,
// so callers that need a hard cap must enforce it themselves.
class Semaphore {
public:
    // Capacity and initial permit count are both `max_count`.
    Semaphore(int max_count) : max_count_(max_count), current_count_(max_count) {}
    // One permit, capacity one.
    Semaphore() : max_count_(1), current_count_(1) {}

    // Raises the capacity to `target` and adds the difference as new permits.
    // A target below the current capacity is ignored (shrinking unsupported).
    void set(int target) {
        std::unique_lock<std::mutex> lock(mutex_);
        if (target < max_count_) {
            return;
        }
        // Compute the delta BEFORE overwriting max_count_; the old order made
        // it always 0, so no permits were ever added.
        const int add = target - max_count_;
        max_count_ = target;
        current_count_ += add;
        if (add > 0) {
            condition_.notify_all();  // new permits: let waiters re-check
        }
    }

    // Takes one permit, waiting at most `timeout` for one to appear.
    // Returns true if a permit was taken, false if the wait timed out.
    bool acquire(std::chrono::milliseconds timeout) {
        std::unique_lock<std::mutex> lock(mutex_);
        // The predicate overload handles spurious wake-ups itself and returns
        // false only when the timeout expires with no permit available.
        if (!condition_.wait_for(lock, timeout,
                                 [this]() { return current_count_ > 0; })) {
            return false;
        }
        --current_count_;
        return true;
    }

    // Returns one permit and wakes one waiter.
    void release() {
        std::unique_lock<std::mutex> lock(mutex_);
        ++current_count_;
        // A permit is now available, so always wake someone; the old
        // `current_count_ <= max_count_` check could leave waiters asleep
        // (e.g. for Semaphore(0)).
        condition_.notify_one();
    }

    // Returns `count` permits at once. Requests outside [1, max_count_] are
    // ignored.
    void release(int count) {
        std::unique_lock<std::mutex> lock(mutex_);
        // max_count_ is checked under the lock because set() may modify it.
        if (count <= 0 || count > max_count_) {
            return;
        }
        current_count_ += count;
        condition_.notify_all();
    }

    // Snapshot of the number of available permits.
    int count() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return current_count_;
    }

private:
    int max_count_;                       // capacity
    int current_count_;                   // permits available now
    mutable std::mutex mutex_;            // guards both counters; mutable for count()
    std::condition_variable condition_;   // signalled when permits are added
};
令牌桶限流组件
class RateLimiter {
private:
int capacity_; // 容量
int tokens_per_second_; // 每秒钟放入令牌的数量
Semaphore tokens_; // 令牌数量
uint64_t free_timestamp_; // 下次可以放入令牌的时间
double quota_ms; // 产生1个quota需要的毫秒
std::mutex mutex_; // 互斥锁
std::chrono::milliseconds timeout; // 超时时间
// 刷新令牌
void AddToken() {
std::lock_guard<std::mutex> lock(mutex_);
uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
if (now > free_timestamp_) {
int add_num = (now - free_timestamp_) * tokens_per_second_ / 1000;
tokens_.release(std::min(capacity_, tokens_.count() + add_num));
double should_add_num =
double(now - free_timestamp_) * tokens_per_second_ / 1000;
free_timestamp_ = now - (should_add_num - add_num) * quota_ms;
}
}
public:
RateLimiter(int capacity, int tokens_per_second)
: capacity_(capacity), tokens_per_second_(tokens_per_second) {
tokens_.set(capacity);
quota_ms = 1000.0 / tokens_per_second_;
timeout = std::chrono::milliseconds(1000 / tokens_per_second);
free_timestamp_ = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count() +
1000; // 应对突发流量
}
void setRate(int capacity, int tokens_per_second){
capacity_ = capacity;
tokens_per_second_ = tokens_per_second;
}
void GetToken() {
while (!tokens_.acquire(timeout)) {
AddToken();
}
}
};
完整代码
#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>
// 基于条件变量实现的信号量
// Counting semaphore built on std::mutex + std::condition_variable.
// max_count_ is the nominal capacity and current_count_ the number of permits
// available right now. release() does not clamp current_count_ to max_count_,
// so callers that need a hard cap must enforce it themselves.
class Semaphore {
public:
    // Capacity and initial permit count are both `max_count`.
    Semaphore(int max_count) : max_count_(max_count), current_count_(max_count) {}
    // One permit, capacity one.
    Semaphore() : max_count_(1), current_count_(1) {}

    // Raises the capacity to `target` and adds the difference as new permits.
    // A target below the current capacity is ignored (shrinking unsupported).
    void set(int target) {
        std::unique_lock<std::mutex> lock(mutex_);
        if (target < max_count_) {
            return;
        }
        // Compute the delta BEFORE overwriting max_count_; the old order made
        // it always 0, so no permits were ever added.
        const int add = target - max_count_;
        max_count_ = target;
        current_count_ += add;
        if (add > 0) {
            condition_.notify_all();  // new permits: let waiters re-check
        }
    }

    // Takes one permit, waiting at most `timeout` for one to appear.
    // Returns true if a permit was taken, false if the wait timed out.
    bool acquire(std::chrono::milliseconds timeout) {
        std::unique_lock<std::mutex> lock(mutex_);
        // The predicate overload handles spurious wake-ups itself and returns
        // false only when the timeout expires with no permit available.
        if (!condition_.wait_for(lock, timeout,
                                 [this]() { return current_count_ > 0; })) {
            return false;
        }
        --current_count_;
        return true;
    }

    // Returns one permit and wakes one waiter.
    void release() {
        std::unique_lock<std::mutex> lock(mutex_);
        ++current_count_;
        // A permit is now available, so always wake someone; the old
        // `current_count_ <= max_count_` check could leave waiters asleep
        // (e.g. for Semaphore(0)).
        condition_.notify_one();
    }

    // Returns `count` permits at once. Requests outside [1, max_count_] are
    // ignored.
    void release(int count) {
        std::unique_lock<std::mutex> lock(mutex_);
        // max_count_ is checked under the lock because set() may modify it.
        if (count <= 0 || count > max_count_) {
            return;
        }
        current_count_ += count;
        condition_.notify_all();
    }

    // Snapshot of the number of available permits.
    int count() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return current_count_;
    }

private:
    int max_count_;                       // capacity
    int current_count_;                   // permits available now
    mutable std::mutex mutex_;            // guards both counters; mutable for count()
    std::condition_variable condition_;   // signalled when permits are added
};
// 令牌桶限流组件
class RateLimiter {
private:
int capacity_; // 容量
int tokens_per_second_; // 每秒钟放入令牌的数量
Semaphore tokens_; // 令牌数量
uint64_t free_timestamp_; // 下次可以放入令牌的时间
double quota_ms; // 产生1个quota需要的毫秒
std::mutex mutex_; // 互斥锁
std::chrono::milliseconds timeout; // 超时时间
// 刷新令牌
void AddToken() {
std::lock_guard<std::mutex> lock(mutex_);
uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
if (now > free_timestamp_) {
int add_num = (now - free_timestamp_) * tokens_per_second_ / 1000;
tokens_.release(std::min(capacity_, tokens_.count() + add_num));
double should_add_num =
double(now - free_timestamp_) * tokens_per_second_ / 1000;
free_timestamp_ = now - (should_add_num - add_num) * quota_ms;
}
}
public:
RateLimiter(int capacity, int tokens_per_second)
: capacity_(capacity), tokens_per_second_(tokens_per_second) {
tokens_.set(capacity);
quota_ms = 1000.0 / tokens_per_second_;
timeout = std::chrono::milliseconds(1000 / tokens_per_second);
free_timestamp_ = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count() +
1000; // 应对突发流量
}
void setRate(int capacity, int tokens_per_second){
capacity_ = capacity;
tokens_per_second_ = tokens_per_second;
}
void GetToken() {
while (!tokens_.acquire(timeout)) {
AddToken();
}
}
};
// Limiter shared by every worker: capacity 8, refill rate 8 tokens/second.
RateLimiter tb(8, 8);

// Worker loop: forever wait for a token, then emit one line (flushed, so the
// throttled rate is visible as it happens).
void pushData() {
    for (;;) {
        tb.GetToken();
        std::cout << "send data" << std::endl;
    }
}
// Demo entry point: starts ten pushData workers that compete for tokens from
// `tb`, then joins them (the workers loop forever, so this never returns).
int main(int argc, const char **argv) {
    constexpr int kWorkerCount = 10;
    std::vector<std::thread> workers;
    workers.reserve(kWorkerCount);
    for (int n = 0; n < kWorkerCount; ++n) {
        workers.emplace_back(pushData);
    }
    for (std::thread &worker : workers) {
        worker.join();
    }
    return 0;
}
© 版权声明
文章版权归作者所有，未经允许请勿转载。
THE END