zoj 3494 BCD Code 数位DP + AC自动机

发布时间:2024年01月20日

BCD Code

1

题意

将十进制数的每一位数位转化成一个 4 4 4 位的二进制数,并给定一些 禁止码,规定符合条件的数字的二进制表示中不能包含连续的某个禁止码。问 [ l , r ] [l,r] [l,r] 中有多少个符合条件的数字

思路

朴素的数位 D P DP DP 只涉及少量的禁止码,而这道题涉及的禁止码多达 100 100 100 个,所以要用 A C AC AC 自动机 来完成多模匹配。我们可以用当前匹配的状态在 A C AC AC 自动机的 T r i e Trie Trie 上的节点编号来表示当前的限制。

d p [ p o s ] [ n o w ] dp[pos][now] dp[pos][now] 表示 p o s pos pos 个全变化位,当前搜到了 T r i e Trie Trie 上的 n o w now now 号节点所包含的符合条件的数量。当我们枚举当前位为 0 0 0 ~ 9 9 9 时,我们要从 n o w now now 往下走 4 4 4 步,这每一步都不能走到某个禁止码的末尾,它的某个 f a i l fail fail 指针指向的点也不能是某个禁止码的末尾。由于 T r i e Trie Trie 建好后就固定了,所以我们可以预处理树上每个点往下走的路径是 0 0 0 ~ 9 9 9 10 10 10 种情况的节点编号,如果路上遇到了某个禁止码,那么就移到 ? 1 -1 ?1 节点。

在记忆化搜索的过程中注意一下前导 0 0 0 就可以了。代码中的 Z Z Z 类型是我借鉴 j i a n g l y jiangly jiangly 的大数取模类型,可以当成一个会自动取模 l o n g l o n g long long longlong 就可以

#include<bits/stdc++.h>
#define fore(i,l,r)	for(int i=(int)(l);i<(int)(r);++i)
#define fi first
#define se second
#define endl '\n'
#define ull unsigned long long
#define ALL(v) v.begin(), v.end()
#define Debug(x, ed) std::cerr << #x << " = " << x << ed;

const int INF=0x3f3f3f3e;
const long long INFLL=1e18;

typedef long long ll;

template<class T>
constexpr T power(T a, ll b){
    T res = 1;
    while(b){
        if(b&1) res = res * a;
        a = a * a;
        b >>= 1;
    }
    return res;
}

constexpr ll mul(ll a,ll b,ll mod){ //快速乘,避免两个long long相乘取模溢出
    ll res = a * b - ll(1.L * a * b / mod) * mod;
    res %= mod;
    if(res < 0) res += mod; //误差
    return res;
}

template<ll P>
struct MLL{
    ll x;
    constexpr MLL() = default;
    constexpr MLL(ll x) : x(norm(x % getMod())) {}

    static ll Mod;
    constexpr static ll getMod(){
       if(P > 0) return P;
       return Mod;
    }

    constexpr static void setMod(int _Mod){
       Mod = _Mod;
    }
    constexpr ll norm(ll x) const{
       if(x < 0){
           x += getMod();
       }
       if(x >= getMod()){
           x -= getMod();
       }
       return x;
    }
    constexpr ll val() const{
       return x;
    }
    explicit constexpr operator ll() const{ 
       return x; //将结构体显示转换为ll类型: ll res = static_cast<ll>(OBJ)
    }
    constexpr MLL operator -() const{ //负号,等价于加上Mod
       MLL res;
       res.x = norm(getMod() - x);
       return res;
    }
    constexpr MLL inv() const{
       assert(x != 0);
       return power(*this, getMod() - 2); //用费马小定理求逆
    }
    constexpr MLL& operator *= (MLL rhs) & { //& 表示“this”指针不能指向一个临时对象或const对象
       x = mul(x, rhs.x, getMod()); //该函数只能被一个左值调用
       return *this;
    }
    constexpr MLL& operator += (MLL rhs) & {
       x = norm(x + rhs.x);
       return *this;
    }
    constexpr MLL& operator -= (MLL rhs) & {
       x = norm(x - rhs.x);
       return *this;
    }
    constexpr MLL& operator /= (MLL rhs) & {
       return *this *= rhs.inv();
    }
    friend constexpr MLL operator * (MLL lhs, MLL rhs){
       MLL res = lhs;
       res *= rhs;
       return res;
    }
    friend constexpr MLL operator + (MLL lhs, MLL rhs){
       MLL res = lhs;
       res += rhs;
       return res;
    }
    friend constexpr MLL operator - (MLL lhs, MLL rhs){
       MLL res = lhs;
       res -= rhs;
       return res;
    }
    friend constexpr MLL operator / (MLL lhs, MLL rhs){
       MLL res = lhs;
       res /= rhs;
       return res;
    }
    friend constexpr std::istream& operator >> (std::istream& is, MLL& a){
       ll v;
       is >> v;
       a = MLL(v);
       return is;
    }
    friend constexpr std::ostream& operator << (std::ostream& os, MLL& a){
       return os << a.val();
    }
    friend constexpr bool operator == (MLL lhs, MLL rhs){
       return lhs.val() == rhs.val();
    }
    friend constexpr bool operator != (MLL lhs, MLL rhs){
       return lhs.val() != rhs.val();
    }
};

const ll mod = 1000000009;
using Z = MLL<mod>;

char code[220]; //错误码
int cnt; //AC自动机节点数量

struct node{
    int son[2];
    bool end; //code结尾标记
    int fail;
}tree[2500];

int nxt[2500][10]; //Trie上的点往后移动后的节点编号

Z dp[220][2500];
int num[220];

void insert(char* s){
    int now = 0;
    int n = strlen(s);
    fore(i, 0, n){
        int ch = s[i] - '0';
        if(!tree[now].son[ch])
            tree[now].son[ch] = ++cnt;
        now = tree[now].son[ch];
    }
    tree[now].end = true;
}

void getFail(){
    std::queue<int> q;
    fore(i, 0, 2)
        if(tree[0].son[i])
            q.push(tree[0].son[i]);
    while(!q.empty()){
        int now = q.front();
        q.pop();
        tree[now].end |= tree[tree[now].fail].end; //code结尾标记向下传递
        fore(i, 0, 2){
            if(tree[now].son[i]){
                tree[tree[now].son[i]].fail = tree[tree[now].fail].son[i];
                q.push(tree[now].son[i]);
            }
            else tree[now].son[i] = tree[tree[now].fail].son[i];
        }
    }
}

int judge(int now, int d){ //d是十进制数,将其转成二进制后在Trie上移动
    for(int i = 3; i >= 0; --i){
        int ch = ((d >> i) & 1);
        if(tree[tree[now].son[ch]].end) return -1; //下一个位置是一个禁止码
        now = tree[now].son[ch]; //往下走
    }
    return now;
}

Z dfs(int pos, int now, bool lead, bool limit){
    if(!pos) return 1;
    if(!lead && !limit && dp[pos][now].x != -1) return dp[pos][now];
    Z res = 0;
    int up = (limit ? num[pos] : 9);
    fore(i, 0, up + 1){
        if(lead && !i) res += dfs(pos - 1, now, true, limit && i == up);
        else if(nxt[now][i] != -1) res += dfs(pos - 1, nxt[now][i], false, limit && i == up);
    }
    if(!lead && !limit) dp[pos][now] = res;
    return res;
}

Z solve(std::string& s){
    int len = 0;
    for(auto it = s.rbegin(); it != s.rend(); ++it){
        num[++len] = *it - '0';
    }
    return dfs(len, 0, true, true);
}

int main(){
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    int t;
    std::cin >> t;
    while(t--){
        fore(i, 0, 220)
            fore(j, 0, 2200)
                dp[i][j].x = -1;
        
        int n;
        std::cin >> n;
        while(n--){
           std::cin >> code; //读入错误码
           insert(code);
        }

        getFail();

        /* 预处理Trie上每个节点往后走的情况的节点编号,以及是否会走到某个code */
        fore(i, 0, cnt + 1)
            fore(j, 0, 10)
                nxt[i][j] = judge(i, j);

        std::string l, r;
        std::cin >> l >> r;

        /* l - 1 */
        for(int i = l.size() - 1; i >= 0; --i)
            if(l[i] == '0')
                l[i] = '9';
            else{
                --l[i];
                break;
            }

        Z ans = solve(r) - solve(l);
        std::cout << ans << endl;

        /* 重置 Trie */
        fore(i, 0, cnt + 1){
            tree[i].fail = tree[i].end = 0;
            fore(j, 0, 2) tree[i].son[j] = 0;
        }
        cnt = 0;
    }
    return 0;
}
文章来源:https://blog.csdn.net/m0_73500785/article/details/135717786
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。