1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
#include <bits/stdc++.h>
using namespace std;

// Contest-style shorthand macros. main() relies on fo/fd/ll; the others are
// kept unchanged so any code that used them keeps compiling.
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cout<<#x<<"="<<x<<endl;
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second

// Reads one decimal integer from stdin. Characters before the first digit are
// skipped; a '-' among them makes the result negative. Returns 0 if EOF is hit
// before any digit (the character is kept as int so EOF is distinguishable).
inline int read()
{
    int x=0;
    int ch=getchar();
    bool f=0;
    for(;ch!=EOF&&(ch<'0'||ch>'9');ch=getchar()) if(ch=='-') f=1;
    for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
    return f?-x:x;
}

const int N=(1<<20)+5;      // capacity of every global coefficient buffer
const ll mod=998244353ll;   // NTT-friendly prime, 119*2^23+1
const ll G=3;               // primitive root modulo mod

// Modular add/subtract/multiply; operands are expected to lie in [0, mod).
inline ll Add(ll x,ll y){x+=y; return (x<mod)?x:x-mod;}
inline ll Dec(ll x,ll y){x-=y; return (x<0)?x+mod:x;}
inline ll Mul(ll x,ll y){return x*y%mod;}

// x^y mod `mod` for any signed 64-bit x and y.
// The base is reduced first so x*x cannot overflow. The exponent is reduced
// modulo mod-1 (Fermat), which is only valid for x != 0, hence the special
// case. Negative y therefore gives inverse powers of an invertible x.
inline ll Pow(ll x,ll y)
{
    x%=mod;
    if(x<0) x+=mod;
    if(x==0) return y==0?1:0;
    y%=(mod-1);
    if(y<0) y+=mod-1;
    ll ans=1;
    for(;y;y>>=1,x=x*x%mod) if(y&1) ans=ans*x%mod;
    return ans;
}

int R[N];       // bit-reversal permutation for the current transform length
ll A[N],B[N];   // scratch buffers; every routine leaves them all-zero on exit

// Prepares R for a transform of length m, the smallest power of two strictly
// greater than n (so a product of degree n does not wrap around). Returns m.
int pre_ntt(int n)
{
    int m,L;
    for(m=1,L=0;m<=n;m<<=1) L++;
    assert(m<N);   // R, A and B hold N entries
    fo(i,1,m-1) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
    return m;
}

// In-place iterative NTT of a[0..len) modulo `mod`.
// opt==1: forward transform. opt==-1: inverse transform including the 1/len
// scaling. len must be a power of two with R prepared by pre_ntt.
inline void ntt(ll *a,int len,int opt)
{
    for(int i=1;i<len;i++) if(i>R[i]) swap(a[i],a[R[i]]);
    for(int i=1;i<len;i<<=1)
    {
        ll wn=Pow(G,(mod-1)/(i<<1));
        if(opt==-1) wn=Pow(wn,mod-2);
        for(int j=0;j<len;j+=(i<<1))
        {
            ll w=1;
            for(int k=0;k<i;k++,w=w*wn%mod)
            {
                ll x=a[j+k];
                ll y=a[i+j+k]*w%mod;
                a[j+k]=(x+y)%mod;
                a[i+j+k]=(x-y+mod)%mod;
            }
        }
    }
    if(opt==1) return;
    ll invn=Pow(len,mod-2);
    // Scale only the len transformed entries; a[len] is not part of the input.
    for(int i=0;i<len;i++) a[i]=a[i]*invn%mod;
}

// c = a*b, with coefficients a[0..na] and b[0..nb].
// Writes c[0..na+nb] and stores na+nb in k.
inline void Pmul(ll *c,ll *a,ll *b,int na,int nb,int &k)
{
    int len=pre_ntt(na+nb);
    k=na+nb;
    fo(i,0,na) A[i]=a[i];
    fo(i,na+1,len-1) A[i]=0;
    fo(i,0,nb) B[i]=b[i];
    fo(i,nb+1,len-1) B[i]=0;
    ntt(A,len,1);
    ntt(B,len,1);
    fo(i,0,len-1) A[i]=A[i]*B[i]%mod;
    ntt(A,len,-1);
    fo(i,0,k) c[i]=A[i];
    fo(i,0,len-1) A[i]=B[i]=0;
}

// b[0..n) = a^{-1} mod x^n, via Newton iteration b <- b*(2 - a*b), which
// doubles the number of correct coefficients per level.
// Requires a[0] != 0 (mod `mod`). Does nothing for n <= 0.
void Pinv(ll *a,ll *b,int n)
{
    if(n<=0) return;
    if(n==1)
    {
        b[0]=Pow(a[0],mod-2);
        return;
    }
    int half=(n+1)>>1;
    Pinv(a,b,half);
    int len=pre_ntt(n<<1);
    // a mod x^n and the current half-precision b, zero-padded to len.
    fo(i,0,n-1) A[i]=a[i];
    fo(i,n,len-1) A[i]=0;
    fo(i,0,half-1) B[i]=b[i];
    fo(i,half,len-1) B[i]=0;
    ntt(A,len,1);
    ntt(B,len,1);
    fo(i,0,len-1) B[i]=(2ll-A[i]*B[i]%mod+mod)%mod*B[i]%mod;
    ntt(B,len,-1);
    fo(i,0,n-1) b[i]=B[i];
    fo(i,0,len-1) A[i]=B[i]=0;
}
int n; ll fac[N],inv[N],a[N],b[N],c[N],g[N],h[N],ans; int p[N]; int main() { n=read(); int m; fac[0]=1; fo(i,1,n) fac[i]=fac[i-1]*i%mod; inv[n]=Pow(fac[n],mod-2); fd(i,n,1) inv[i-1]=inv[i]*i%mod; fo(i,0,n) a[i]=Pow(2,1ll*i*(i-1)/2)*inv[i]%mod; Pinv(a,h,1+n); fo(i,0,n) b[i]=h[i]*a[n-i]%mod*fac[n-i]%mod; fo(i,0,n) c[i]=inv[i]; Pmul(g,b,c,n,n,m); fo(i,0,n) g[i]=g[i]*fac[i]%mod; fo(i,1,n) p[i]=read(); sort(p+1,p+n+1); fo(i,1,n) ans=Add(ans,Mul(p[i],g[i-1]-g[i]+mod)); printf("%lld\n",ans); return 0; }
|