ARC115 E

AtCoder

Problem Statement

Given is a sequence of NN integers A1A_1,A2A_2,...,ANA_N. Print the number, modulo 998244353998244353, of sequence of NN integers X1X_1,X2X_2,...,XNX_N satisfying all of the following conditions:
1XiAi1 \leq X_i \leq A_i

  • Xi=Xi+1(1iN1)X_i \not = X_{i+1} (1 \leq i \leq N-1)

Constraints

  • 2N51052 \leq N \leq 5 * {10}^5
  • 1Ai1091 \leq A_i \leq {10}^9

Input

Input is given from Standard Input in the following format:

NA1A2...ANN \\ A_1 A_2 ... A_N

Output

Print the answer.

解法

题意说n个位置,每个位置上面的数字不能大于Ai,问每对相邻的数都不相同的序列数有多少个。这种问题一看就是容斥,用所有的减去不符合的。不符合的分为至少某1个位置不符合,至少某2个位置不符合……这样就可以用dp去做了。
dp[i][j]代表前i个元素分为j段的方案数,使得每段内的所有元素都相等。那么答案其实就是dp[n][n]-dp[n][n-1]+dp[n][n-2]...。这个转移方程是显然的:

dp[i][j]=ki1dp[k][j1]mink+1liAkdp[i][j] = \sum_{k \leq i-1} dp[k][j-1] * \min_{k+1 \leq l \leq i} A_k

但是这个转移怎么看都要O(N2)O(N^2),不过好在最终的容斥式子系数只与j的奇偶性有关,于是只考虑奇偶性转移:

dp[i][1]=ki1dp[k][0]mink+1liAkdp[i][0]=ki1dp[k][1]mink+1liAkdp[i][1] = \sum_{k \leq i-1} dp[k][0] * \min_{k+1 \leq l \leq i} A_k \\ dp[i][0] = \sum_{k \leq i-1} dp[k][1] * \min_{k+1 \leq l \leq i} A_k

现在就是怎么去做这个转移的问题了。考虑到某个AiA_i的时候,AiA_i作为新的一段的转移,AiA_i是新的一段中的最小值:

[A1 , ... Aj] [Aj+1 ... Ai ... Ak]

这里[Aj+1...Ai...Ak][A_{j+1} ... A_i ... A_k]是新添加上去的一段,可以发现l(i)j<il(i) \leq j \lt iikr(i)i \leq k \leq r(i),其中l(i)l(i)AiA_i左边第一个小于等于AiA_i的下标,r(i)r(i)AiA_i右边第一个小于AiA_i的下标。这里用了不同的符号是规定同样大的数字,前面的更小,防止重复更新同一段。这么规定也不会漏掉,因为每一个新增的段一定有一个AiA_i会被我们遍历到。l(i)l(i)r(i)r(i)可以通过单调栈轻松计算,这里不赘述。
这样的话,如果我们维护了当前元素左边所有dp值的前缀和,那么我们就可以快速获得所有满足条件的jj的dp和,然后更新到这一段的末尾可能的取值,即kk的范围:

rangeAdd(i+1, r[i], 0, preSum(l[i], i-1, 1) * A[i]);
rangeAdd(i+1, r[i], 1, preSum(l[i], i-1, 0) * A[i]);

这里的preSum(l, r, p)lirdp[i][p]\sum_{l \leq i \leq r} dp[i][p]rangeAdd可以通过数据结构维护,这里我们采用数组这种快速的数组结构来维护它:

rangeAdd(l, r, p, v) => diff[l][p] += v, diff[r+1][p] -= v;

这样这道题就做完了,以下是AC代码:

#pragma GCC optimize ("Ofast,unroll-loops")
#pragma GCC optimize("no-stack-protector,fast-math")

#include <bits/stdc++.h>

using namespace std;

constexpr int N = 5e5+7;
constexpr int M = 998244353;

#define fastio ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
#define int long long
#define pii pair<int, int>
#define fi first
#define se second
#define SZ(x) ((int)(x.size()))

#ifdef int
#define INF 0x3f3f3f3f3f3f3f3f
#define INF2 (int)(0xcfcfcfcfcfcfcfcf)
#else
#define INF 0x3f3f3f3f
#define INF2 0xcfcfcfcf
#endif


signed main() {
    fastio
    int n;
    cin >> n;
    vector<int> a(n+1, 0);
    for (int i = 1; i <= n; i++) cin >> a[i];
    auto add = [&](int& x, int y) {
        x = (x % M + M + y % M) % M;
    };
    vector<int> l(n+1, 0), r(n+1, n+1);
    vector<int> st;
    for (int i = 1; i <= n; i++) {
        while (!st.empty() and a[st.back()] > a[i])
            st.pop_back();
        l[i] = st.empty() ? 0 : st.back();
        st.emplace_back(i);
    }
    st.clear();
    for (int i = n; i >= 1; i--) {
        while (!st.empty() and a[st.back()] >= a[i])
            st.pop_back();
        r[i] = st.empty() ? n+1 : st.back();
        st.emplace_back(i);
    }
    st.clear();

    vector<array<int, 2>> dp(n+1, {0, 0});
    vector<array<int, 2>> sum(n+1, {0, 0});
    vector<array<int, 2>> diff(n+2, {0, 0});
    dp[0][0] = 1;
    sum[0][0] = dp[0][0];
    auto rangeAdd = [&](int l, int r, int parity, int v) {
        add(diff[l][parity], v);
        add(diff[r+1][parity], -v);
    };
    auto preSum = [&](int l, int r, int parity) {
        return (sum[r][parity] - (l ? sum[l-1][parity] : 0ll) + M) % M;
    };
    for (int i = 1; i <= n; i++) {
        int ll = l[i], lr = i-1;
        int rl = i, rr = r[i]-1;

        add(diff[i][0], diff[i-1][0]);
        add(diff[i][1], diff[i-1][1]);

        if (ll <= lr and rl <= rr) {
            rangeAdd(rl, rr, 0, preSum(ll, lr, 1) * a[i]);
            rangeAdd(rl, rr, 1, preSum(ll, lr, 0) * a[i]);
        }

        add(dp[i][0], diff[i][0]);
        add(dp[i][1], diff[i][1]);

        add(sum[i][0], sum[i-1][0]);
        add(sum[i][1], sum[i-1][1]);
        add(sum[i][0], dp[i][0]);
        add(sum[i][1], dp[i][1]);
    }

    cout << (dp[n][n&1] - dp[n][1^(n&1)] + M)%M << "\n";
    return 0;
}