[前题提要]:感觉dp还是很显然的,感觉单调栈也不是很难想,但是VP的时候脑子比较乱,dp方程想偏了,没写出来…
看完题目,不难发现应该存在一种递推关系.因为会发现最后剩下来的必然是原序列的一种子序列,然后这种子序列计数的问题.应该想到使用dp计数.
刚开始我的想法是使用
d
p
[
i
]
dp[i]
dp[i]记录前
i
i
i位的方案数,然后发现这种方式极难递推.给两个
h
a
c
k
hack
hack样例:
2 4 1
3 2 1
发现光光使用上面的dp方程,对于上述两种样例没办法区分.
所以考虑优化一下我们的dp方程,使用
d
p
[
i
]
dp[i]
dp[i]来记录前
i
i
i位并且最后一位是
i
i
i的子序列个数.
然后想一下该怎么进行转移,对于当前位置
i
i
i,如果它能从
j
j
j位置转移过来,需要满足什么条件.
首先我们区间 [ j + 1 , i ? 1 ] [j+1,i-1] [j+1,i?1]的所有数应该是小于我们的 a [ j ] , a [ i ] a[j],a[i] a[j],a[i]的,这样我们才能使用端点将区间中的所有数字删掉.也就是说,我们的所有能转移的 j 1 , j 2 , j 3 . . . j n j_1,j_2,j_3...j_n j1?,j2?,j3?...jn?必须满足 a [ j 1 ] < a [ j 2 ] < . . . < a [ j n ] < a [ i ] a[j_1]<a[j_2]<...<a[j_n]<a[i] a[j1?]<a[j2?]<...<a[jn?]<a[i]
不难发现,我们的 a [ j n ] a[j_n] a[jn?]应该是 i i i位置左边第一个小于 a [ i ] a[i] a[i]的数.这个我们可以使用单调栈预处理出来(使用单调栈倒推即可,此处不在赘述).然后我们会发现 [ j n , i ? 1 ] [j_n,i-1] [jn?,i?1]的所有数字都可以转移过来(中间数被 a [ i ] a[i] a[i]位置的数字删掉),然后对于上述的 j 1 j_1 j1?~ j n ? 1 j_{n-1} jn?1?(注意此处是单点),也都可以转移过来.并且除了上述的点以外,没有一个点能转移到 i i i,因为对于其他点来说,区间中一定有一个点的值比端点小.
对于区间
[
j
n
+
1
,
i
]
[j_n+1,i]
[jn?+1,i]的贡献,我们可以使用前缀和处理一下
对于所有单点贡献
∑
d
p
[
j
i
]
\sum dp[j_i]
∑dp[ji?],我们也可以使用前缀和处理,只不过此时的前缀和的下标变了而已.
需要注意的是我们最后剩下来的数字必然是后缀最小值,所以最终我们的答案是就是所有后缀最小值的贡献和
下面是具体的代码部分:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define root 1,n,1
#define ls (rt<<1)
#define rs (rt<<1|1)
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
inline ll read() {
ll x=0,w=1;char ch=getchar();
for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x*w;
}
inline void print(__int128 x){
if(x<0) {putchar('-');x=-x;}
if(x>9) print(x/10);
putchar(x%10+'0');
}
#define maxn 1000000
#define int long long
const int mod=998244353;
const double eps=1e-8;
#define int_INF 0x3f3f3f3f
#define ll_INF 0x3f3f3f3f3f3f3f3f
int p[maxn];int dp[maxn];int sum1[maxn],sum2[maxn];
signed main() {
int T=read();
while(T--) {
int n=read();map<int,int>pos;
for(int i=1;i<=n;i++) {
p[i]=read();pos[p[i]]=i;
}
stack<int>s;map<int,int>mp;
s.push(p[n]);
for(int i=n-1;i>=1;i--) {
while(!s.empty()&&p[i]<s.top()) {
mp[s.top()]=i;
s.pop();
}
s.push(p[i]);
}
dp[1]=1;sum1[1]=1;sum2[1]=1;
for(int i=2;i<=n;i++) {
dp[i]=(sum1[i-1]-sum1[mp[p[i]]]+mod)%mod;
dp[i]=(dp[i]+sum2[mp[p[i]]])%mod;
if(mp[p[i]]==0) dp[i]=(dp[i]+1)%mod;
sum2[i]=(sum2[mp[p[i]]]+dp[i])%mod;
sum1[i]=(sum1[i-1]+dp[i])%mod;
}
int minn=p[n];
int ans=0;
for(int i=n;i>=1;i--) {
minn=min(minn,p[i]);
if(minn==p[i]) {
ans=(ans+dp[i])%mod;
}
}
cout<<ans<<endl;
for(int i=1;i<=n;i++) {
dp[i]=0;sum1[i]=sum2[i]=0;
}
}
return 0;
}