跳转至

斯坦纳树

引入

斯坦纳树有点类似与一个可以新加入点的最小生成树, 最小生成树要求是在给的的边权之中连接给定的点集, 而斯坦纳树对于边集没有限制,只需要使给定的自己联通即可:

引入例题

给定一个包含 \(n\) 个结点和 \(m\) 条带权边的无向连通图 \(G=(V,E)\)

再给定包含 \(k\) 个结点的点集 \(S\),选出 \(G\) 的子图 \(G'=(V',E')\),使得:

  1. \(S\subseteq V'\)

  2. \(G'\) 为连通图;

  3. \(E'\) 中所有边的权值和最小。

你只需要求出 \(E'\) 中所有边的权值和。

对于 \(100\%\) 的数据,\(1\leq n\leq 100,\ \ 1\leq m\leq 500,\ \ 1\leq k\leq 10,\ \ 1\leq u,v\leq n,\ \ 1\leq w\leq 10^6\)

保证给出的无向图连通,但 可能 存在重边和自环。

但是其实斯坦纳树的正解也没有什么特殊的,就是 DP动态规划 。

最开始我的想法是直接令 \(f(G)\) 为联通子集 \(G\) 的代价,但是此时当我们合并两个集合时完全不知道代价,所以我们对于每一个联通子集 \(G\) 寻找一个代表元素 \(i\) ,这个 \(i\) 不一定在 \(G\) 中但是能与 \(G\) 的每一个点联通。

所有我们令 \(f(i,G)\) 为令子集 \(G\) 联通,并且 \(G\) 的所有点能够联通 \(i\) 的最小代价。 此时有两个转移方式:

  • \(f(i,S) + f(i,T-S) \to f(i,T) ~~~~(S\in T)\) (合并两个子集)

  • \(f(i,S) + w(i,j) \to f(j,S)\) (单独拓展一个集合)

但是需要注意这里的转移顺序,对于每一个集合 \(S\) ,我们应当先把合并操作做完(因为合并指挥由其子集而来),然后再对于 \(S\) 之间转移。

对于第一个,直接暴力枚举其子集,时间复杂度证明如下:

时间复杂度证明

对于外层枚举 \(S\) 时间复杂度显然是 \(\mathcal O(2^n)\)

然后假设 \(S\) 集合 \(1\) 的个数为 \(|S|\) ,那么内层枚举子集个数为 \(\frac{2^{|S|}-2}{2}=2^{|S|-1}-1\)

考虑枚举 \(|S|\) ,那么总时间复杂度为:

$\sum_{m = 1}^k \binom{k}{m} (2^{m-1} - 1) $

$= \frac12 \sum_{m = 1}^k \binom{k}{m} 2^m - \sum_{m = 1}^k \binom{k}{m} $

$= \frac12 \big( (1+2)^k - 1 \big) - \big( 2^k - 1 \big) $

$= \frac12 (3^k - 1) - (2^k - 1) $

$= \frac12 (3^k - 1) - 2^k + 1 $

\(= \frac12 3^k - 2^k + \frac12\)

大致时间复杂度 \(O(n*3^n)\)

后半部分直接写一个最短路就可以了。

示例代码
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <bits/stdc++.h>
using namespace std;
/*~~~~~~~~~~~~~~~~~~~~ Boundary Line ~~~~~~~~~~~~~~~~~~~~*/
const int inf = 0x7f7f7f7f;
const int N=105, M=1030;
int n,m,k;

int f[N][M];
vector< pair<int,int> > v[N];
/*~~~~~~~~~~~~~~~~~~~~ Boundary Line ~~~~~~~~~~~~~~~~~~~~*/
bool vis[N];
void spfa(int dep){
    queue< int > q;
    for(int i=1; i<=n; i++){
        vis[i]=0;
        if(f[i][dep]<inf) q.push(i);
    }

    while(!q.empty()){
        int x=q.front(); q.pop();
        vis[x]=0;

        for(auto to: v[x]){
            int y=to.first, w=to.second;
            if(f[y][dep]>f[x][dep]+w){
                f[y][dep]=f[x][dep]+w;
                if(vis[y]==0){
                    vis[y]=1;
                    q.push(y);
                }
            }
        }
    }
}


/*~~~~~~~~~~~~~~~~~~~~ Boundary Line ~~~~~~~~~~~~~~~~~~~~*/
signed main() {
    cin>>n>>m>>k;

    for(int i=1; i<=n; i++){
        for(int j=0; j<(1<<k); j++)
            f[i][j]=inf;
    }

    for(int i=1; i<=m; i++){
        int x,y,w; cin>>x>>y>>w;
        v[x].push_back({y,w});
        v[y].push_back({x,w});
    }

    for(int i=1; i<=k; i++){
        int x; cin>>x;
        f[x][(1<<(i-1))]=0;
    }

    for(int i=1; i<(1<<k); i++){
        for(int j=(i&(i-1)); j; j=(i&(j-1))){
            if(j > (i^j)) continue;
            for(int u=1; u<=n; u++)
                f[u][i]=min(f[u][i],f[u][j]+f[u][i^j]);
        }
        spfa(i);
    }

    int ans=inf;
    for(int i=1; i<=n; i++)
        ans=min(ans,f[i][(1<<k)-1]);

    cout<<ans<<'\n';

    return 0;
}

扩展

对于需要输出方案的情况,记录每一个转移,然后倒着 DFS 。

例题: P4294 [WC2008] 游览计划

示例代码
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <bits/stdc++.h>
#define int long long
#define PII pair<int,int>
using namespace std;
/*~~~~~~~~~~~~~~~~~~~~ Boundary Line ~~~~~~~~~~~~~~~~~~~~*/
const int inf = 0x7f7f7f7f;
const int N=105;
int n,m,k;

int a[15][15];
int f[N][1030];
PII pre[N][1030];

const int dx[N]={0,0,-1,1};
const int dy[N]={-1,1,0,0};
/*~~~~~~~~~~~~~~~~~~~~ Boundary Line ~~~~~~~~~~~~~~~~~~~~*/
#define pos(x,y) (((x)-1)*m+(y))
#define sop(x) (pair<int,int>((x-1)/m+1,(x-1)%m+1))

bool vis[N];
void spfa(int dep){
    queue<PII> q;
    for(int i=1; i<=n*m; i++){
        vis[i]=0;
        if(f[i][dep]<inf) q.push(sop(i));
    }

    while(!q.empty()){
        int x=q.front().first, y=q.front().second; 
        q.pop(); vis[pos(x,y)]=0;

        for(int i=0; i<4; i++){
            int nx=x+dx[i], ny=y+dy[i], w=a[nx][ny];
            if(nx<1 || ny<1 || nx>n || ny>m) continue;
            if(f[pos(nx,ny)][dep] > f[pos(x,y)][dep] + w){
                f[pos(nx,ny)][dep] = f[pos(x,y)][dep] + w;
                if(vis[pos(nx,ny)]==0){
                    vis[pos(nx,ny)]=1;
                    q.push({nx,ny});
                }
                pre[pos(nx,ny)][dep]=make_pair(pos(x,y),dep);
            }
        }
    }
}

int ans[N];
void Get_ans(int x,int y){
    if(!pre[x][y].first) return;
    ans[x]=1;
    if(pre[x][y].first==x) {
        Get_ans(x, y^pre[x][y].second);
    }
    Get_ans(pre[x][y].first, pre[x][y].second);
}

/*~~~~~~~~~~~~~~~~~~~~ Boundary Line ~~~~~~~~~~~~~~~~~~~~*/
signed main(){
    cin>>n>>m;

    for(int i=1; i<=n*m; i++){
        for(int j=0; j<=1025; j++)
            f[i][j]=inf;
    }

    int rt=-1;

    for(int i=1; i<=n; i++){
        for(int j=1; j<=m; j++){
            cin>>a[i][j];
            if(a[i][j]==0) {
                k++, rt=pos(i,j);
                f[pos(i,j)][1<<(k-1)]=0;
            }
        }
    }
    if(!k){
        cout<<0<<'\n';
        for(int i=1; i<=n; i++){
            for(int j=1; j<=m; j++) cout<<"_";
            cout<<'\n';
        }
        return 0;
    }
    for(int i=1; i<(1<<k); i++){
        for(int j=(i&(i-1)); j; j=(i&(j-1))){
            if(j>(i^j)) continue;
            for(int u=1; u<=n*m; u++){
                int x=sop(u).first, y=sop(u).second;
                if(f[u][i] > f[u][j]+f[u][i^j]-a[x][y]){
                    f[u][i] = f[u][j]+f[u][i^j]-a[x][y];
                    pre[u][i]=make_pair(u,j); 
                }
            }
        }
        spfa(i);
    }

    cout<<f[rt][(1<<k)-1]<<'\n';

    Get_ans(rt,(1<<k)-1);

    for(int i=1; i<=n; i++){
        for(int j=1; j<=m; j++){
            if(a[i][j]==0) cout<<'x';
            else if(ans[pos(i,j)]==1) cout<<'o';
            else cout<<'_';
        }
        cout<<'\n';
    }

    return 0;
}