Problem

Minimum Cost to Merge Stones

Approach

  1. First of all, we have to check whether the given stones can be merged into one pile at all. Every merge replaces K piles with one pile, so it reduces the number of piles by exactly K - 1. Going from N piles down to 1 therefore requires N - 1 to be a multiple of K - 1; when N >= K this is the same as checking that N - K is a multiple of K - 1. A single pile (N = 1) costs 0, and 1 < N < K is impossible.
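  2. Then, let dp[i][len] be the minimum cost to merge the stones stones[i], stones[i + 1], …, stones[i + len - 1] into as few piles as possible (a single pile whenever len - 1 is a multiple of K - 1).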
    • Initialize all elements of dp with Integer.MAX_VALUE, so that taking the minimum over candidate splits works.
    • In particular, if 1 <= len < K, dp[i][len] = 0: fewer than K piles cannot be merged, so leaving them alone costs nothing. These zeros are needed when combining parts. For example, in dp[i][len] = dp[i][part] + dp[i+part][len-part] with K < len < 2*K - 1, the right part dp[i+part][len-part] can be shorter than K and contribute 0, so the value of dp[i][part] carries over into dp[i][len].
    • The base case is dp[i][K] = stones[i] + … + stones[i + K - 1], the cost of merging exactly K piles once.
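    • Then, for K < len <= N, dp[i][len] = min(dp[i][len], dp[i][part] + dp[i+part][len-part]). It is enough to consider left parts that collapse to a single pile, so part only takes the values 1, K, 2K - 1, … below len (it grows by K - 1 each time), which is why the inner loop in the code steps by K - 1.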
    • More importantly, after minimizing over the splits, if the whole segment can be merged into one pile, i.e. (len - 1) % (K - 1) == 0, we have to add the cost of that last merge, which is the sum stones[i] + … + stones[i + len - 1], to dp[i][len].
  3. Lastly, the answer is dp[0][N]
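The three steps can be condensed into the following sketch. It is an illustrative variant, not the original solution: it assumes a prefix array of length N + 1 (prefix[i + 1] = stones[0] + … + stones[i]) and relies on Java's zero-initialized arrays for the len < K case instead of an Integer.MAX_VALUE sentinel. The class and method names (MergeStonesSketch, canMergeToOne, leftPartLengths, minCost) are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

class MergeStonesSketch {
    // Step 1: every merge turns K piles into 1, removing exactly K - 1 piles,
    // so n piles can end as one pile only if (n - 1) is a multiple of (k - 1).
    static boolean canMergeToOne(int n, int k) {
        return (n - 1) % (k - 1) == 0;
    }

    // Step 2: lengths of a left part that can collapse to a single pile:
    // 1, K, 2K - 1, ... (each step adds K - 1), strictly shorter than len.
    static List<Integer> leftPartLengths(int len, int k) {
        List<Integer> parts = new ArrayList<>();
        for (int part = 1; part < len; part += k - 1) {
            parts.add(part);
        }
        return parts;
    }

    // Steps 2-3: dp[i][len] = minimum cost to shrink stones[i .. i + len - 1]
    // to as few piles as possible; entries with len < k stay 0 (no merge possible).
    static int minCost(int[] stones, int k) {
        int n = stones.length;
        if (!canMergeToOne(n, k)) {
            return -1;
        }
        int[] prefix = new int[n + 1];
        for (int i = 0; i < n; i++) {
            prefix[i + 1] = prefix[i] + stones[i];
        }
        int[][] dp = new int[n + 1][n + 1];
        for (int len = k; len <= n; len++) {
            for (int i = 0; i + len <= n; i++) {
                int best = Integer.MAX_VALUE;
                // part = 1 is always a legal split, so best ends up finite.
                for (int part : leftPartLengths(len, k)) {
                    best = Math.min(best, dp[i][part] + dp[i + part][len - part]);
                }
                // The segment can become one pile: pay for that final merge.
                if ((len - 1) % (k - 1) == 0) {
                    best += prefix[i + len] - prefix[i];
                }
                dp[i][len] = best;
            }
        }
        return dp[0][n]; // Step 3
    }

    public static void main(String[] args) {
        System.out.println(canMergeToOne(4, 3));                  // false
        System.out.println(leftPartLengths(7, 3));                // [1, 3, 5]
        System.out.println(minCost(new int[]{3, 2, 4, 1}, 2));    // 20
        System.out.println(minCost(new int[]{3, 5, 1, 2, 6}, 3)); // 25
    }
}
```

Dropping the MAX_VALUE sentinel is a design choice of this sketch: since part = 1 is always a valid split, every dp entry with len >= K receives a finite value before it is read.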

Code

class Solution {
    public int mergeStones(int[] stones, int K) {
        int N = stones.length;
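        // A single pile needs no merging.
        if(N == 1){
            return 0;
        }
        // Each merge removes exactly K - 1 piles, so going from N piles to 1
        // requires N - 1 to be a multiple of K - 1 (this also rejects 1 < N < K).
        if((N - 1) % (K - 1) != 0){
            return -1;
        }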
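        // dp[i][len]: minimum cost to merge stones[i .. i + len - 1] into as few piles as possible.
        // Integer.MAX_VALUE marks "not computed yet".
        int[][] dp = new int[N][N + 1];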
        for(int i = 0; i < N; i ++){
            for(int j = 0; j < N + 1; j ++){
                dp[i][j] = Integer.MAX_VALUE;
            }
        }
        // Prefix sums: sum[i] = stones[0] + ... + stones[i].
        int[] sum = new int[N];
        sum[0] = stones[0];
        for(int i = 1; i < N; i ++){
            sum[i] = sum[i - 1] + stones[i];
        }
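        // Base case: a window of exactly K piles is merged in one step,
        // costing the sum of stones[i .. i + K - 1].
        for(int i = 0; i + K - 1 < N; i ++){
            int j = i + K - 1;
            dp[i][K] = sum[j];
            if(i >= 1){
                dp[i][K] -= sum[i - 1];
            }
        }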
        // Fewer than K piles cannot be merged; leaving them as they are costs 0.
        for(int i = 0; i < N; i ++){
            for(int j = 1; j < K; j ++){
                dp[i][j] = 0;
            }
        }
        for(int span = K + 1; span <= N; span ++){
            for(int i = 0; i < N; i ++){
                int j = i + span - 1;
                if(j >= N){
                    break;
                }
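                // Segment stones[i .. j] inclusive. Split it into a left part of length k
                // that collapses to one pile (k = 1, K, 2K - 1, ...) and the remaining right part.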
                for(int k = 1; k < span; k += K - 1){
                    if(dp[i][k] < Integer.MAX_VALUE && dp[i + k][span - k] < Integer.MAX_VALUE){
                        dp[i][span] = Math.min(dp[i][span], dp[i][k] + dp[i + k][span - k]);
                    }
                }
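                // If the segment can be reduced to one pile, add the cost of the
                // final merge: the sum of stones[i .. j].
                if((span - K) % (K - 1) == 0){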
                    dp[i][span] += sum[j];
                    if(i >= 1){
                        dp[i][span] -= sum[i - 1];
                    }
                }
            }
        }
        return dp[0][N];
    }
}
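To sanity-check the DP on small inputs, one option is an exhaustive search that tries every legal merge order and memoizes on the current list of piles. The class below is a hypothetical helper (MergeStonesBruteForce, solve and minCost are names introduced here, not part of the original solution); it is exponential and only meant for a handful of piles.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class MergeStonesBruteForce {
    private final int k;
    // Memo keyed by the current sequence of pile sizes (List uses content-based equals/hashCode).
    private final Map<List<Integer>, Integer> memo = new HashMap<>();

    MergeStonesBruteForce(int k) {
        this.k = k;
    }

    // Minimum cost to reduce `piles` to a single pile, or -1 if impossible.
    int solve(List<Integer> piles) {
        if (piles.size() == 1) {
            return 0;
        }
        if (piles.size() < k) {
            return -1; // more than one pile left but not enough to merge
        }
        Integer cached = memo.get(piles);
        if (cached != null) {
            return cached;
        }
        int best = -1;
        // Try merging every window of k consecutive piles.
        for (int i = 0; i + k <= piles.size(); i++) {
            int merged = 0;
            for (int j = i; j < i + k; j++) {
                merged += piles.get(j);
            }
            List<Integer> next = new ArrayList<>(piles.subList(0, i));
            next.add(merged);
            next.addAll(piles.subList(i + k, piles.size()));
            int rest = solve(next);
            if (rest >= 0 && (best < 0 || merged + rest < best)) {
                best = merged + rest;
            }
        }
        memo.put(piles, best);
        return best;
    }

    static int minCost(int[] stones, int k) {
        List<Integer> piles = new ArrayList<>();
        for (int s : stones) {
            piles.add(s);
        }
        return new MergeStonesBruteForce(k).solve(piles);
    }

    public static void main(String[] args) {
        System.out.println(minCost(new int[]{3, 2, 4, 1}, 2));    // 20
        System.out.println(minCost(new int[]{3, 2, 4, 1}, 3));    // -1
        System.out.println(minCost(new int[]{3, 5, 1, 2, 6}, 3)); // 25
    }
}
```

Comparing its output with Solution.mergeStones on random arrays of, say, up to 8 piles is a cheap way to catch off-by-one errors in the interval bounds.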