CompetitiveProgrammingCpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub

✔️ Library/DataStructure/DisjointSparseTable.hpp

Depends on: Library/Algebraic/SemiGroup.hpp

Verified with

Code

#pragma once

#include <cmath>

#include <stdexcept>

#include <vector>


#include "../Algebraic/SemiGroup.hpp"


namespace mtd {

  /// Disjoint sparse table: an immutable structure answering folds of a
  /// semigroup over a closed range `[l, r]` with at most one
  /// `binaryOperation` per query.
  ///
  /// Level 0 of the table holds the input values. Level k (k >= 1) splits
  /// `[0, n)` into blocks of length 2^k; inside each block, positions left of
  /// the block midpoint store the fold from that position up to the midpoint
  /// (exclusive), and positions from the midpoint on store the fold from the
  /// midpoint up to that position (inclusive). Two distinct indices whose
  /// highest differing bit is bit k-1 lie in opposite halves of one level-k
  /// block, so combining their two entries covers the whole range.
  ///
  /// NOTE(review): results are only meaningful if `SG`'s operation is
  /// associative; this is not checked.
  template <semigroup SG>
  class DisjointSparseTable {
    using S = typename SG::value_type;

    const int m_n;
    const std::vector<std::vector<SG>> m_table;

    /// Builds the entries of one level block `[l, r)`, with `r` clipped to
    /// `n`. The split point is the midpoint of the *unclipped* block, so it
    /// matches the bit structure get() relies on.
    constexpr static auto accumulation(int n, const std::vector<S>& a, int l,
                                       int r) {
      const auto mid = (r + l) >> 1;  // midpoint of the unclipped block
      r = std::min(n, r);
      std::vector<SG> acc;
      acc.reserve(r - l);
      for (int i = l; i < r; ++i) { acc.emplace_back(a[i]); }
      // Left half: suffix folds a[i] op ... op a[mid - 1]. If the block is
      // clipped before `mid`, the folds end at the last existing element.
      for (int i = std::min(mid, r) - 2; i >= l; --i) {
        acc[i - l] = acc[i - l].binaryOperation(acc[i - l + 1]);
      }
      // Right half: prefix folds a[mid] op ... op a[i].
      for (int i = mid + 1; i < r; ++i) {
        acc[i - l] = acc[i - l - 1].binaryOperation(acc[i - l]);
      }
      return acc;
    }

    /// Builds every level of the table for the first `n` elements of `a`.
    /// Throws std::runtime_error if `n` is negative or exceeds `a.size()`,
    /// since either would make construction or queries read out of bounds.
    constexpr static auto constructTable(int n, const std::vector<S>& a) {
      if (n < 0 || std::ssize(a) < n) {
        throw std::runtime_error("ERROR! `n` must be in [0, a.size()]");
      }
      // Integer level count: level 0 plus one per doubling until a block
      // covers n. (std::log2 is -inf for n == 0 and rounds down for
      // non-powers of two, which under-reserves by one level.)
      int levels = 1;
      for (int len = 1; len < n; len <<= 1) { ++levels; }

      std::vector<std::vector<SG>> table;
      table.reserve(levels);
      table.emplace_back(a.begin(), a.end());

      auto size = 1;
      while (size < n) {
        size <<= 1;
        // Build the level in place rather than copying a temporary into it.
        auto& level = table.emplace_back();
        level.reserve(n);
        for (int l = 0; l < n; l += size) {
          const auto block = accumulation(n, a, l, l + size);
          level.insert(level.end(), block.begin(), block.end());
        }
      }
      return table;
    }

    /// Returns the number of bits needed to represent `x` (the 1-based
    /// position of its highest set bit), or 0 when `x <= 0`. For `l != r`,
    /// msb(l ^ r) is the level whose blocks put `l` and `r` in opposite halves.
    constexpr static auto msb(int x) {
      auto idx = 0;
      while (x > 0) {
        ++idx;
        x >>= 1;
      }
      return idx;
    }

  public:
    /// Builds the table over `a[0..n)`.
    /// Throws std::runtime_error if `n < 0` or `n > a.size()`.
    DisjointSparseTable(int n, const std::vector<S>& a)
        : m_n(n), m_table(constructTable(n, a)) {}

    /// Returns the left-to-right fold of a[max(l, 0)] .. a[min(r, n - 1)].
    /// Throws std::runtime_error if `r < l`, or if `[l, r]` does not overlap
    /// `[0, n)` (a semigroup has no identity to return for an empty range).
    constexpr auto get(int l, int r) const {
      if (r < l) { throw std::runtime_error("ERROR! `l` must less than `r`"); }
      l = std::max(l, 0);
      r = std::min(r, m_n - 1);
      // Clamping can empty the range (l >= n or r < 0); indexing the table
      // afterwards would be out of bounds.
      if (r < l) {
        throw std::runtime_error("ERROR! [`l`, `r`] does not intersect [0, n)");
      }
      if (l == r) { return m_table[0][l].m_val; }
      const auto idx = msb(l ^ r);
      return m_table[idx][l].binaryOperation(m_table[idx][r]).m_val;
    }
  };
}  // namespace mtd
#line 2 "Library/DataStructure/DisjointSparseTable.hpp"

#include <cmath>

#include <stdexcept>

#include <vector>


#line 2 "Library/Algebraic/SemiGroup.hpp"

#include <iostream>

namespace mtd {

  /// A value of set `S` paired with a binary operation `op`.
  ///
  /// `op` is default-constructed on every binaryOperation() call and must be
  /// invocable as `S(S, S)` (enforced by the requires-clause).
  /// NOTE(review): associativity of `op` is assumed, not checked.
  template <class S,  // set
            class op  // binary operation
            >
  requires std::is_invocable_r_v<S, op, S, S>
  struct SemiGroup {
    using value_type = S;
    using op_type = op;

    S m_val;

    // Implicit on purpose: binaryOperation returns `op`'s raw `S` result,
    // which converts back to a SemiGroup through this constructor.
    constexpr SemiGroup(S val) : m_val(val) {}

    /// Returns `op()(m_val, s.m_val)`: `*this` is the left operand, `s` the
    /// right one.
    constexpr SemiGroup binaryOperation(const SemiGroup& s) const {
      return op()(m_val, s.m_val);
    }

    // Not constexpr: std::ostream insertion can never appear in a constant
    // expression, and a constexpr function that can never be constant-
    // evaluated is ill-formed (no diagnostic required) in C++20.
    friend std::ostream& operator<<(std::ostream& os,
                                    const SemiGroup<S, op>& s) {
      return os << s.m_val;
    }
  };

  namespace __detail {
    // Satisfied iff T is exactly S<T::value_type, T::op_type>. Any
    // substitution failure (e.g. T has no value_type) leaves it unsatisfied.
    template <typename T, template <typename, typename> typename S>
    concept is_semigroup_specialization_of = requires {
      typename std::enable_if_t<
          std::is_same_v<T, S<typename T::value_type, typename T::op_type>>>;
    };
  }  // namespace __detail

  /// Satisfied exactly by specializations of mtd::SemiGroup.
  template <typename G>
  concept semigroup = __detail::is_semigroup_specialization_of<G, SemiGroup>;

}  // namespace mtd
#line 8 "Library/DataStructure/DisjointSparseTable.hpp"

namespace mtd {

  /// Disjoint sparse table: an immutable structure answering folds of a
  /// semigroup over a closed range `[l, r]` with at most one
  /// `binaryOperation` per query.
  ///
  /// Level 0 of the table holds the input values. Level k (k >= 1) splits
  /// `[0, n)` into blocks of length 2^k; inside each block, positions left of
  /// the block midpoint store the fold from that position up to the midpoint
  /// (exclusive), and positions from the midpoint on store the fold from the
  /// midpoint up to that position (inclusive). Two distinct indices whose
  /// highest differing bit is bit k-1 lie in opposite halves of one level-k
  /// block, so combining their two entries covers the whole range.
  ///
  /// NOTE(review): results are only meaningful if `SG`'s operation is
  /// associative; this is not checked.
  template <semigroup SG>
  class DisjointSparseTable {
    using S = typename SG::value_type;

    const int m_n;
    const std::vector<std::vector<SG>> m_table;

    /// Builds the entries of one level block `[l, r)`, with `r` clipped to
    /// `n`. The split point is the midpoint of the *unclipped* block, so it
    /// matches the bit structure get() relies on.
    constexpr static auto accumulation(int n, const std::vector<S>& a, int l,
                                       int r) {
      const auto mid = (r + l) >> 1;  // midpoint of the unclipped block
      r = std::min(n, r);
      std::vector<SG> acc;
      acc.reserve(r - l);
      for (int i = l; i < r; ++i) { acc.emplace_back(a[i]); }
      // Left half: suffix folds a[i] op ... op a[mid - 1]. If the block is
      // clipped before `mid`, the folds end at the last existing element.
      for (int i = std::min(mid, r) - 2; i >= l; --i) {
        acc[i - l] = acc[i - l].binaryOperation(acc[i - l + 1]);
      }
      // Right half: prefix folds a[mid] op ... op a[i].
      for (int i = mid + 1; i < r; ++i) {
        acc[i - l] = acc[i - l - 1].binaryOperation(acc[i - l]);
      }
      return acc;
    }

    /// Builds every level of the table for the first `n` elements of `a`.
    /// Throws std::runtime_error if `n` is negative or exceeds `a.size()`,
    /// since either would make construction or queries read out of bounds.
    constexpr static auto constructTable(int n, const std::vector<S>& a) {
      if (n < 0 || std::ssize(a) < n) {
        throw std::runtime_error("ERROR! `n` must be in [0, a.size()]");
      }
      // Integer level count: level 0 plus one per doubling until a block
      // covers n. (std::log2 is -inf for n == 0 and rounds down for
      // non-powers of two, which under-reserves by one level.)
      int levels = 1;
      for (int len = 1; len < n; len <<= 1) { ++levels; }

      std::vector<std::vector<SG>> table;
      table.reserve(levels);
      table.emplace_back(a.begin(), a.end());

      auto size = 1;
      while (size < n) {
        size <<= 1;
        // Build the level in place rather than copying a temporary into it.
        auto& level = table.emplace_back();
        level.reserve(n);
        for (int l = 0; l < n; l += size) {
          const auto block = accumulation(n, a, l, l + size);
          level.insert(level.end(), block.begin(), block.end());
        }
      }
      return table;
    }

    /// Returns the number of bits needed to represent `x` (the 1-based
    /// position of its highest set bit), or 0 when `x <= 0`. For `l != r`,
    /// msb(l ^ r) is the level whose blocks put `l` and `r` in opposite halves.
    constexpr static auto msb(int x) {
      auto idx = 0;
      while (x > 0) {
        ++idx;
        x >>= 1;
      }
      return idx;
    }

  public:
    /// Builds the table over `a[0..n)`.
    /// Throws std::runtime_error if `n < 0` or `n > a.size()`.
    DisjointSparseTable(int n, const std::vector<S>& a)
        : m_n(n), m_table(constructTable(n, a)) {}

    /// Returns the left-to-right fold of a[max(l, 0)] .. a[min(r, n - 1)].
    /// Throws std::runtime_error if `r < l`, or if `[l, r]` does not overlap
    /// `[0, n)` (a semigroup has no identity to return for an empty range).
    constexpr auto get(int l, int r) const {
      if (r < l) { throw std::runtime_error("ERROR! `l` must less than `r`"); }
      l = std::max(l, 0);
      r = std::min(r, m_n - 1);
      // Clamping can empty the range (l >= n or r < 0); indexing the table
      // afterwards would be out of bounds.
      if (r < l) {
        throw std::runtime_error("ERROR! [`l`, `r`] does not intersect [0, n)");
      }
      if (l == r) { return m_table[0][l].m_val; }
      const auto idx = msb(l ^ r);
      return m_table[idx][l].binaryOperation(m_table[idx][r]).m_val;
    }
  };
}  // namespace mtd
Back to top page