n(n<=5e5)个数,第i个数ai(1<=ai<=1e9)
构造一个序列b,要求bi∈[1,ai],且b[i]不等于b[i+1]
求方案数,答案对998244353取模
洛谷题解Xu_brezza
一模一样的cf题:
Codeforces Round 759 (Div. 2, based on Technocup 2022 Elimination Round 3) F. Non-equal Neighbours
首先肯定是容斥,假设出现了一对冲突的就叫有一个坏点,
那么,答案=没有冲突的-至少一个冲突的+至少两个冲突的...
出现了一个坏点,就认为是合并减少了一个数,
所以最后如果减少了k个数,就认为序列被拆成了n-k段,且每段内的数字相同
dp[i][j]表示前i个数被划分成了j段的方案数,
其中每段内的数字是相同的,也就是从[1,这一段的最小值]中取
1. 朴素转移即枚举最后一段在哪,补上这最后一段的贡献,对应了这一个区间的最小值
复杂度
2. 注意到,第二维对容斥系数的贡献,只有第二维的奇偶性,所以可以改写为
前缀和分别维护第二维为奇数/为偶数的和,
复杂度
3. 单调栈优化,注意到,如果找到了前面第一个小于ai的数是p,
那么,ai对前面的位置的数,也就是k<p的位置不再有贡献,
考虑k<p的位置的和,
所以有,
那么,一边维护单调栈一边维护前缀和即可,
特别地,如果不存在这样的p,说明ai是前缀最小,
观察原式后代入ai,有
前0个数分成偶数段方案数为1,分成奇数段方案数为0,对应dp[0][0]=1
最后,答案=至少出现偶数次冲突-至少出现奇数次冲突方案
如果n为偶数,说明分成偶数段=出现偶数次冲突,答案是
否则,如果n为奇数,说明分成奇数段=出现偶数次冲突,答案是
复杂度
#include<iostream>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
#define fi first
#define se second
#define pb push_back
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
#define SZ(a) (int)(a.size())
#define sci(a) scanf("%d",&(a))
#define scll(a) scanf("%lld",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d\n",a)
#define ptlle(a) printf("%lld\n",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
const int N=5e5+10,mod=998244353;
int n,a[N],dp[N][2],sum[N][2],stk[N],c,ans;
// dp[i][j&1]+=dp[p][j&1]+sum dp[k][j&1^1]*a[i],p为小于ai的最大位置
void add(int &x,int y){
x=(x+y)%mod;
}
int cal(int l,int r,int v){
if(l==0)return sum[r][v];
return (sum[r][v]-sum[l-1][v]+mod)%mod;
}
int main(){
sci(n);
dp[0][0]=sum[0][0]=1;
rep(i,1,n){
sci(a[i]);
while(c && a[stk[c]]>=a[i]){
c--;
}
rep(j,0,1){
if(c)dp[i][j]=(dp[stk[c]][j]+1ll*cal(stk[c],i-1,j^1)*a[i]%mod)%mod;
else dp[i][j]=1ll*cal(0,i-1,j^1)*a[i]%mod;
sum[i][j]=(sum[i-1][j]+dp[i][j])%mod;
}
stk[++c]=i;
//printf("i:%d dp:(%d,%d)\n",i,dp[i][0],dp[i][1]);
}
if(n&1)ans=(dp[n][1]-dp[n][0]+mod)%mod;
else ans=(dp[n][0]-dp[n][1]+mod)%mod;
printf("%d\n",ans);
return 0;
}
自己的乱搞,基于以下这个式子,构造单调栈优化
其中,dp[i]表示以i结尾的方案数,枚举最后一段合并了几个数,
假设最后一段x个数,说明冲突了x-1次,容斥系数是(-1)的x-1次方,加上对应的贡献即可
#include<iostream>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
#define fi first
#define se second
#define pb push_back
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
#define SZ(a) (int)(a.size())
#define sci(a) scanf("%d",&(a))
#define scll(a) scanf("%lld",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d\n",a)
#define ptlle(a) printf("%lld\n",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
const int N=5e5+10,mod=998244353;
//dp[i]+=dp[j]*min(a[j+1],...,a[i])*(-1)^(i-j-1),j<i
//dp[1]+=dp[0]*a[1]
//dp[2]+=dp[1]*a[2]
//dp[2]-=dp[0]*min(a[1],a[2])
int n,a[N],dp[N],sum[2],sum2[N],stk[N],mn[N],cp[N][2],c;
void add(int &x,int y){
x=(x+y)%mod;
}
int main(){
sci(n);
rep(i,1,n){
sci(a[i]);
if(i==1)mn[i]=a[i];
else mn[i]=min(mn[i-1],a[i]);
}
rep(i,0,n){
while(c && a[stk[c]]>=a[i]){
int p=stk[c],x=(p-1)&1;
add(sum[x],mod-1ll*cp[p-1][0]*a[p]%mod);//j
add(sum2[x],cp[p-1][0]);//j
add(sum[x^1],mod-1ll*cp[p-1][1]*a[p]%mod);//j
add(sum2[x^1],cp[p-1][1]);//j
c--;
}//1 2 4 6 5 7
//printf("i:%d sum:(%d,%d) sum2:(%d,%d)\n",i,sum[0],sum[1],sum2[0],sum2[1]);
int v1=(sum[i&1^1]-sum[i&1]+mod)%mod;
int v2=1ll*(sum2[i&1^1]-sum2[i&1]+mod)*a[i]%mod;
dp[i]=(v1+v2)%mod;
if(i&1)add(dp[i],mn[i]);//dp[0]的贡献
else add(dp[i],mod-mn[i]);
//printf("i:%d dp:%d\n",i,dp[i]);
int x=(i-1)&1;
stk[++c]=i;
cp[i-1][0]=sum2[x];
cp[i-1][1]=sum2[x^1];
add(sum[x],1ll*cp[i-1][0]*a[i]%mod);
add(sum[x^1],1ll*cp[i-1][1]*a[i]%mod);
sum2[0]=sum2[1]=0;
while(c && a[stk[c]]>=a[i+1]){
int p=stk[c],x=(p-1)&1;
add(sum[x],mod-1ll*cp[p-1][0]*a[p]%mod);//j
add(sum2[x],cp[p-1][0]);//j
add(sum[x^1],mod-1ll*cp[p-1][1]*a[p]%mod);//j
add(sum2[x^1],cp[p-1][1]);//j
c--;
}
stk[++c]=i+1;
x=i&1;
cp[i][0]=(sum2[x]+dp[i])%mod;
cp[i][1]=sum2[x^1];
sum2[0]=sum2[1]=0;
add(sum[x],1ll*cp[i][0]*a[i+1]%mod);
add(sum[x^1],1ll*cp[i][1]*a[i+1]%mod);
}
printf("%d\n",dp[n]);
return 0;
}