基于信号量与令牌桶算法实现的简单限流组件

基于条件变量的信号量简单实现

// Counting semaphore built on std::mutex + std::condition_variable.
//
// max_count_ is the configured capacity. release() does not clamp
// current_count_ to it, so the count can exceed max_count_ if callers
// release more units than they acquired.
class Semaphore {
 public:
  // Creates a semaphore whose capacity and initial count are both max_count.
  Semaphore(int max_count) : max_count_(max_count), current_count_(max_count) {}
  // Creates a semaphore with capacity 1 and one unit available.
  Semaphore() : max_count_(1), current_count_(1) {}

  // Raises the capacity to `target` and adds the difference to the current
  // count. Requests that would lower the capacity are ignored.
  void set(int target) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (target < max_count_) {
      return;
    }
    // The delta must be computed before max_count_ is overwritten.
    const int add = target - max_count_;
    max_count_ = target;
    current_count_ += add;
    if (add > 0) {
      condition_.notify_all();
    }
  }

  // Takes one unit, waiting at most `timeout` for one to become available.
  // Returns false if the timeout expires while the count is still zero.
  bool acquire(std::chrono::milliseconds timeout) {
    std::unique_lock<std::mutex> lock(mutex_);
    // The predicate overload re-checks after spurious wakeups and returns
    // immediately when a unit is already available.
    if (!condition_.wait_for(lock, timeout,
                             [this]() { return current_count_ > 0; })) {
      return false;
    }
    --current_count_;
    return true;
  }

  // Returns one unit and wakes one waiter.
  void release() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++current_count_;
    // A waiter re-checks the predicate, so notifying unconditionally is safe
    // and guarantees it is not left sleeping while a unit is available.
    condition_.notify_one();
  }

  // Returns `count` units at once and wakes all waiters.
  // Requests with count <= 0 or count > max_count_ are ignored.
  void release(int count) {
    std::lock_guard<std::mutex> lock(mutex_);
    // max_count_ can be changed by set(), so it is only read under the lock.
    if (count <= 0 || count > max_count_) {
      return;
    }
    current_count_ += count;
    condition_.notify_all();
  }

  // Number of units currently available (a snapshot; may be stale on return).
  int count() {
    std::lock_guard<std::mutex> lock(mutex_);
    return current_count_;
  }

 private:
  int max_count_;
  int current_count_;
  std::mutex mutex_;
  std::condition_variable condition_;
};

基于令牌桶的限流组件

class RateLimiter {
 private:
  int capacity_;                      // 容量
  int tokens_per_second_;             // 每秒钟放入令牌的数量
  Semaphore tokens_;                  // 令牌数量
  uint64_t free_timestamp_;           // 下次可以放入令牌的时间
  double quota_ms;                    // 产生1个quota需要的毫秒
  std::mutex mutex_;                  // 互斥锁
  std::chrono::milliseconds timeout;  // 超时时间

  // 刷新令牌
  void AddToken() {
    std::lock_guard<std::mutex> lock(mutex_);
    uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>(
                       std::chrono::system_clock::now().time_since_epoch())
                       .count();
    if (now > free_timestamp_) {
      int add_num = (now - free_timestamp_) * tokens_per_second_ / 1000;
      tokens_.release(std::min(capacity_, tokens_.count() + add_num));
      double should_add_num =
          double(now - free_timestamp_) * tokens_per_second_ / 1000;
      free_timestamp_ = now - (should_add_num - add_num) * quota_ms;
    }
  }

 public:
  RateLimiter(int capacity, int tokens_per_second)
      : capacity_(capacity), tokens_per_second_(tokens_per_second) {
    tokens_.set(capacity);
    quota_ms = 1000.0 / tokens_per_second_;
    timeout = std::chrono::milliseconds(1000 / tokens_per_second);
    free_timestamp_ = std::chrono::duration_cast<std::chrono::milliseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count() +
                      1000;  // 应对突发流量
  }

  void setRate(int capacity, int tokens_per_second){
    capacity_ = capacity;
    tokens_per_second_ = tokens_per_second;
    
  }

  void GetToken() {
    while (!tokens_.acquire(timeout)) {
      AddToken();
    }
  }
};

完整代码

#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Counting semaphore built on std::mutex + std::condition_variable.
//
// max_count_ is the configured capacity. release() does not clamp
// current_count_ to it, so the count can exceed max_count_ if callers
// release more units than they acquired.
class Semaphore {
 public:
  // Creates a semaphore whose capacity and initial count are both max_count.
  Semaphore(int max_count) : max_count_(max_count), current_count_(max_count) {}
  // Creates a semaphore with capacity 1 and one unit available.
  Semaphore() : max_count_(1), current_count_(1) {}

  // Raises the capacity to `target` and adds the difference to the current
  // count. Requests that would lower the capacity are ignored.
  void set(int target) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (target < max_count_) {
      return;
    }
    // The delta must be computed before max_count_ is overwritten.
    const int add = target - max_count_;
    max_count_ = target;
    current_count_ += add;
    if (add > 0) {
      condition_.notify_all();
    }
  }

  // Takes one unit, waiting at most `timeout` for one to become available.
  // Returns false if the timeout expires while the count is still zero.
  bool acquire(std::chrono::milliseconds timeout) {
    std::unique_lock<std::mutex> lock(mutex_);
    // The predicate overload re-checks after spurious wakeups and returns
    // immediately when a unit is already available.
    if (!condition_.wait_for(lock, timeout,
                             [this]() { return current_count_ > 0; })) {
      return false;
    }
    --current_count_;
    return true;
  }

  // Returns one unit and wakes one waiter.
  void release() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++current_count_;
    // A waiter re-checks the predicate, so notifying unconditionally is safe
    // and guarantees it is not left sleeping while a unit is available.
    condition_.notify_one();
  }

  // Returns `count` units at once and wakes all waiters.
  // Requests with count <= 0 or count > max_count_ are ignored.
  void release(int count) {
    std::lock_guard<std::mutex> lock(mutex_);
    // max_count_ can be changed by set(), so it is only read under the lock.
    if (count <= 0 || count > max_count_) {
      return;
    }
    current_count_ += count;
    condition_.notify_all();
  }

  // Number of units currently available (a snapshot; may be stale on return).
  int count() {
    std::lock_guard<std::mutex> lock(mutex_);
    return current_count_;
  }

 private:
  int max_count_;
  int current_count_;
  std::mutex mutex_;
  std::condition_variable condition_;
};

// 令牌桶限流组件
class RateLimiter {
 private:
  int capacity_;                      // 容量
  int tokens_per_second_;             // 每秒钟放入令牌的数量
  Semaphore tokens_;                  // 令牌数量
  uint64_t free_timestamp_;           // 下次可以放入令牌的时间
  double quota_ms;                    // 产生1个quota需要的毫秒
  std::mutex mutex_;                  // 互斥锁
  std::chrono::milliseconds timeout;  // 超时时间

  // 刷新令牌
  void AddToken() {
    std::lock_guard<std::mutex> lock(mutex_);
    uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>(
                       std::chrono::system_clock::now().time_since_epoch())
                       .count();
    if (now > free_timestamp_) {
      int add_num = (now - free_timestamp_) * tokens_per_second_ / 1000;
      tokens_.release(std::min(capacity_, tokens_.count() + add_num));
      double should_add_num =
          double(now - free_timestamp_) * tokens_per_second_ / 1000;
      free_timestamp_ = now - (should_add_num - add_num) * quota_ms;
    }
  }

 public:
  RateLimiter(int capacity, int tokens_per_second)
      : capacity_(capacity), tokens_per_second_(tokens_per_second) {
    tokens_.set(capacity);
    quota_ms = 1000.0 / tokens_per_second_;
    timeout = std::chrono::milliseconds(1000 / tokens_per_second);
    free_timestamp_ = std::chrono::duration_cast<std::chrono::milliseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count() +
                      1000;  // 应对突发流量
  }

  void setRate(int capacity, int tokens_per_second){
    capacity_ = capacity;
    tokens_per_second_ = tokens_per_second;
    
  }

  void GetToken() {
    while (!tokens_.acquire(timeout)) {
      AddToken();
    }
  }
};

// Limiter shared by all worker threads: capacity 8 tokens, refilled at 8 tokens per second.
RateLimiter tb(8, 8);

// Worker loop: forever takes one token from the shared limiter `tb`, then
// prints one "send data" line.
void pushData() {
  // Shared by every thread running pushData. Without it, concurrent writers
  // may interleave characters, e.g. two texts before either newline.
  static std::mutex cout_mutex;
  while (true) {
    tb.GetToken();
    std::lock_guard<std::mutex> lock(cout_mutex);
    // std::endl flushes, so each permitted send is visible immediately.
    std::cout << "send data" << std::endl;
  }
}

// Demo entry point: starts ten worker threads running pushData, all sharing
// the global limiter `tb`, then joins them. pushData loops forever, so in
// practice the joins never complete.
int main(int argc, const char **argv) {
  constexpr int kWorkerCount = 10;
  std::vector<std::thread> workers;
  workers.reserve(kWorkerCount);
  for (int n = 0; n < kWorkerCount; ++n) {
    workers.emplace_back(pushData);
  }
  for (std::thread &worker : workers) {
    worker.join();
  }
  return 0;
}

 

------本页内容已结束,喜欢请分享------

文章作者
能不能吃完饭再说
隐私政策
PrivacyPolicy
用户协议
UseGenerator
许可协议
NC-SA 4.0


© 版权声明
THE END
喜欢就支持一下吧
点赞28赞赏 分享
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

取消
昵称表情代码图片