Random Pick with Weight

Published: Dec 28, 2022

Medium | Binary Search | Prefix Sum

Problem Description

You are given a 0-indexed array of positive integers w where w[i] describes the weight of the i-th index.

You need to implement the function pickIndex(), which randomly picks an index in the range [0, w.length - 1] (inclusive) and returns it. The probability of picking an index i is w[i] / sum(w).

  • For example, if w = [1, 3], the probability of picking index 0 is 1 / (1 + 3) = 0.25 (i.e., 25%), and the probability of picking index 1 is 3 / (1 + 3) = 0.75 (i.e., 75%).
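To make the target distribution concrete, here is a small helper that computes w[i] / sum(w) exactly with Python's fractions module, so the numbers can be compared against the example without rounding. The name `target_probabilities` is hypothetical and not part of the problem's interface.

```python
from fractions import Fraction
from typing import List


def target_probabilities(w: List[int]) -> List[Fraction]:
    """Return the exact probability w[i] / sum(w) for every index i (hypothetical helper)."""
    total = sum(w)
    return [Fraction(x, total) for x in w]


print(target_probabilities([1, 3]))  # [Fraction(1, 4), Fraction(3, 4)]
```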

Constraints:

  • 1 <= w.length <= 10^4
  • 1 <= w[i] <= 10^5
  • pickIndex will be called at most 10^4 times.

https://leetcode.com/problems/random-pick-with-weight/

Examples

Example 1
Input
["Solution","pickIndex"]
[[[1]],[]]
Output
[null,0]

Explanation
Solution solution = new Solution([1]);
solution.pickIndex(); // return 0. The only option is to return 0 since there is only one element in w.

Example 2
Input
["Solution","pickIndex","pickIndex","pickIndex","pickIndex","pickIndex"]
[[[1,3]],[],[],[],[],[]]
Output
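[null,1,1,1,1,0]

Explanation
Solution solution = new Solution([1, 3]);
solution.pickIndex(); // returns 1 with probability 3/4 and 0 with probability 1/4.
Because pickIndex is random, [null,1,1,1,1,0] is only one of many valid outputs.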

How to Solve

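The standard approach combines a prefix sum with binary search. Build the running totals prefix[i] = w[0] + ... + w[i]. Index i then owns the half-open interval [prefix[i] - w[i], prefix[i]), whose length is exactly w[i], so a value drawn uniformly from [0, sum(w)) lands in index i's interval with probability w[i] / sum(w). To pick an index, draw such a value and binary-search for the first index whose prefix sum is greater than it. How the value is drawn depends on the language's random function. The C++ version draws a random integer in [0, sum(w)) and compares it directly against the integer prefix sums, so it does not need to normalize the weights. Python's random.random() and Ruby's rand (with no arguments) return a float in [0, 1), so those versions first divide the prefix sums by the total, turning them into cumulative probabilities between 0 and 1.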
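The prefix-sum-and-binary-search idea can be checked by hand on the w = [1, 3] example. The sketch below enumerates every integer target in [0, sum(w)) and shows which index it maps to. It uses the standard library's `itertools.accumulate` and `bisect.bisect_right` in place of a hand-written loop; `bisect_right` returns the first position whose value is greater than the target, which is the same condition the binary search looks for. The helper names `build_prefix` and `index_for` are hypothetical.

```python
import bisect
import itertools
from typing import List


def build_prefix(w: List[int]) -> List[int]:
    """prefix[i] = w[0] + ... + w[i]; index i owns [prefix[i] - w[i], prefix[i])."""
    return list(itertools.accumulate(w))


def index_for(prefix: List[int], target: int) -> int:
    """Return the first index whose prefix sum is strictly greater than target."""
    return bisect.bisect_right(prefix, target)


prefix = build_prefix([1, 3])
print(prefix)  # [1, 4]
print([index_for(prefix, t) for t in range(prefix[-1])])  # [0, 1, 1, 1]
```

Exactly one of the four targets lands on index 0 and three land on index 1, matching the 25% / 75% split in the problem statement.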

Solution

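#include <random>
#include <vector>
using namespace std;

class Solution {
private:
    vector<int> weights;                  // weights[i] = w[0] + ... + w[i]
    mt19937 gen;                          // random engine, seeded once
    uniform_int_distribution<int> dist;   // uniform integers in [0, sum(w) - 1]
public:
    Solution(vector<int>& w) : gen(random_device{}()) {
        weights.push_back(w[0]);
        for (size_t i = 1; i < w.size(); ++i) {
            weights.push_back(w[i] + weights.back());
        }
        // sum(w) <= 10^4 * 10^5 = 10^9, which fits in an int.
        dist = uniform_int_distribution<int>(0, weights.back() - 1);
    }
    
    int pickIndex() {
        int weight = dist(gen);
        // Find the first index whose prefix sum is greater than weight.
        int left = 0, right = static_cast<int>(weights.size()) - 1;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (weight >= weights[mid]) {
                left = mid + 1;
            } else {
                right = mid;
            }
        }
        return left;
    }
};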


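import random
from typing import List


class Solution:

    def __init__(self, w: List[int]):
        total = sum(w)
        # Cumulative probabilities: weights[i] = (w[0] + ... + w[i]) / total
        self.weights = []
        running = 0
        for v in w:
            running += v
            self.weights.append(running / total)
        # Pin the last entry to exactly 1.0 so floating-point rounding can
        # never leave a random value without a matching index.
        self.weights[-1] = 1.0

    def pickIndex(self) -> int:
        weight = random.random()  # uniform float in [0.0, 1.0)
        # Find the first index whose cumulative probability exceeds weight.
        left, right = 0, len(self.weights) - 1
        while left < right:
            mid = (left + right) // 2
            if weight >= self.weights[mid]:
                left = mid + 1
            else:
                right = mid
        return left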

class Solution

=begin
    :type w: Integer[]
=end
    def initialize(w)
        total = w.sum.to_f
        running = 0
        # Cumulative probabilities: @weights[i] = (w[0] + ... + w[i]) / total
        @weights = w.map { |v| (running += v) / total }
        # Pin the last entry to exactly 1.0 to absorb floating-point rounding.
        @weights[-1] = 1.0
    end


=begin
    :rtype: Integer
=end
    def pick_index()
        weight = rand  # uniform float in [0.0, 1.0)
        # Find the first index whose cumulative probability exceeds weight.
        left, right = 0, @weights.size - 1
        while left < right
            mid = (left + right) / 2
            if weight >= @weights[mid]
                left = mid + 1
            else
                right = mid
            end
        end
        left
    end
end
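Normalizing to floats, as the Python and Ruby versions do, is not the only option. As a sketch of the integer approach the C++ version takes, the class below keeps integer prefix sums and draws an integer target uniformly from [0, sum(w)) with `random.randrange`, so no floating-point rounding is involved; the lookup uses `bisect.bisect_right` rather than a hand-written loop. The class name `IntegerWeightedPicker` and its optional `seed` parameter are hypothetical, added for testing, and are not part of the LeetCode interface.

```python
import bisect
import itertools
import random
from typing import List, Optional


class IntegerWeightedPicker:
    """Weighted index picker over integer prefix sums (hypothetical class name)."""

    def __init__(self, w: List[int], seed: Optional[int] = None):
        # prefix[i] = w[0] + ... + w[i], kept as exact integers
        self.prefix = list(itertools.accumulate(w))
        # Private generator; seed=None seeds it from the system
        self.rng = random.Random(seed)

    def pick_index(self) -> int:
        target = self.rng.randrange(self.prefix[-1])  # uniform integer in [0, sum(w))
        # First index whose prefix sum is strictly greater than target
        return bisect.bisect_right(self.prefix, target)


picker = IntegerWeightedPicker([1, 3], seed=0)
counts = [0, 0]
for _ in range(100_000):
    counts[picker.pick_index()] += 1
print([c / sum(counts) for c in counts])
```

With 100,000 draws the observed frequencies should land close to 0.25 and 0.75. Staying in integers trades the one-time division by the total for exact comparisons, which is why the last entry never needs to be pinned to 1.0.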

Complexities

  • Time: O(n) for the constructor and O(log n) per pickIndex call, where n = w.length
  • Space: O(n) for the prefix sums
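To make the O(log n) bound concrete: the binary search at least halves the candidate range on each iteration, so it needs at most ceil(log2(n)) iterations. The sketch below counts the iterations of the same loop for n = 10^4, the largest size the constraints allow, over every possible integer target. The `search_steps` helper is hypothetical, and w = [1] * n is an arbitrary choice of weights for the measurement.

```python
import math
from typing import List


def search_steps(prefix: List[int], target: int) -> int:
    """Count iterations of the solutions' binary search (hypothetical helper)."""
    left, right = 0, len(prefix) - 1
    steps = 0
    while left < right:
        steps += 1
        mid = (left + right) // 2
        if target >= prefix[mid]:
            left = mid + 1
        else:
            right = mid
    return steps


n = 10**4
prefix = list(range(1, n + 1))  # prefix sums of w = [1] * n
worst = max(search_steps(prefix, t) for t in range(prefix[-1]))
print(worst, math.ceil(math.log2(n)))  # 14 14
```

At that size each pickIndex call needs at most 14 comparisons, so even the maximum of 10^4 calls costs on the order of 1.4 * 10^5 comparisons.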