Skip to content

Files

165 lines (123 loc) · 4.97 KB

File metadata and controls

165 lines (123 loc) · 4.97 KB

English Version

题目描述

给你一个下标从 0 开始的非负整数数组 nums 和两个整数 l 和 r 。

请你返回 nums 中子多重集合的和在闭区间 [l, r] 之间的 子多重集合的数目

由于答案可能很大,请你将答案对 109 + 7 取余后返回。

子多重集合 指的是从数组中选出一些元素构成的 无序 集合,每个元素 x 出现的次数可以是 0, 1, ..., occ[x] 次,其中 occ[x] 是元素 x 在数组中的出现次数。

注意:

  • 如果两个子多重集合中的元素排序后一模一样,那么它们两个是相同的 子多重集合 。
  •  集合的和是 0 。

 

示例 1:

输入:nums = [1,2,2,3], l = 6, r = 6
输出:1
解释:唯一和为 6 的子集合是 {1, 2, 3} 。

示例 2:

输入:nums = [2,1,4,2,7], l = 1, r = 5
输出:7
解释:和在闭区间 [1, 5] 之间的子多重集合为 {1} ,{2} ,{4} ,{2, 2} ,{1, 2} ,{1, 4} 和 {1, 2, 2} 。

示例 3:

输入:nums = [1,2,1,3,5,2], l = 3, r = 5
输出:9
解释:和在闭区间 [3, 5] 之间的子多重集合为 {3} ,{5} ,{1, 2} ,{1, 3} ,{2, 2} ,{2, 3} ,{1, 1, 2} ,{1, 1, 3} 和 {1, 2, 2} 。

 

提示:

  • 1 <= nums.length <= 2 * 104
  • 0 <= nums[i] <= 2 * 104
  • nums 的和不超过 2 * 104
  • 0 <= l <= r <= 2 * 104

解法

Python3

class Solution:
    def countSubMultisets(self, nums: List[int], l: int, r: int) -> int:
        kMod = 1_000_000_007
        # dp[i] := # of submultisets of nums with sum i
        dp = [1] + [0] * r
        count = collections.Counter(nums)
        zeros = count.pop(0, 0)

        for num, freq in count.items():
            # stride[i] := dp[i] + dp[i - num] + dp[i - 2 * num] + ...
            stride = dp.copy()
            for i in range(num, r + 1):
                stride[i] += stride[i - num]
            for i in range(r, 0, -1):
                if i >= num * (freq + 1):
                    # dp[i] + dp[i - num] + dp[i - freq * num]
                    dp[i] = stride[i] - stride[i - num * (freq + 1)]
                else:
                    dp[i] = stride[i]

        return (zeros + 1) * sum(dp[l : r + 1]) % kMod

Java

class Solution {
    static final int MOD = 1_000_000_007;
    public int countSubMultisets(List<Integer> nums, int l, int r) {
        Map<Integer, Integer> count = new HashMap<>();
        int total = 0;
        for (int num : nums) {
            total += num;
            if (num <= r) {
                count.merge(num, 1, Integer::sum);
            }
        }
        if (total < l) {
            return 0;
        }
        r = Math.min(r, total);
        int[] dp = new int[r + 1];
        dp[0] = count.getOrDefault(0, 0) + 1;
        count.remove(Integer.valueOf(0));
        int sum = 0;
        for (Map.Entry<Integer, Integer> e : count.entrySet()) {
            int num = e.getKey();
            int c = e.getValue();
            sum = Math.min(sum + c * num, r);
            // prefix part
            // dp[i] = dp[i] + dp[i - num] + ... + dp[i - c*num] + dp[i-(c+1)*num] + ... + dp[i % num]
            for (int i = num; i <= sum; i++) {
                dp[i] = (dp[i] + dp[i - num]) % MOD;
            }
            int temp = (c + 1) * num;
            // correction part
            // subtract dp[i - (freq + 1) * num] to the end part.
            // leves dp[i] = dp[i] + dp[i-num] +...+ dp[i - c*num];
            for (int i = sum; i >= temp; i--) {
                dp[i] = (dp[i] - dp[i - temp] + MOD) % MOD;
            }
        }
        int ans = 0;
        for (int i = l; i <= r; i++) {
            ans += dp[i];
            ans %= MOD;
        }
        return ans;
    }
}

C++

Go

...