Problem Link

Description


Given a (0-indexed) integer array nums and two integers low and high, return the number of nice pairs.

A nice pair is a pair (i, j) where 0 <= i < j < nums.length and low <= (nums[i] XOR nums[j]) <= high.


Example 1:

Input: nums = [1,4,2,7], low = 2, high = 6
Output: 6
Explanation: All nice pairs (i, j) are as follows:
    - (0, 1): nums[0] XOR nums[1] = 5 
    - (0, 2): nums[0] XOR nums[2] = 3
    - (0, 3): nums[0] XOR nums[3] = 6
    - (1, 2): nums[1] XOR nums[2] = 6
    - (1, 3): nums[1] XOR nums[3] = 3
    - (2, 3): nums[2] XOR nums[3] = 5

Example 2:

Input: nums = [9,8,4,2,1], low = 5, high = 14
Output: 8
Explanation: All nice pairs (i, j) are as follows:
    - (0, 2): nums[0] XOR nums[2] = 13
    - (0, 3): nums[0] XOR nums[3] = 11
    - (0, 4): nums[0] XOR nums[4] = 8
    - (1, 2): nums[1] XOR nums[2] = 12
    - (1, 3): nums[1] XOR nums[3] = 10
    - (1, 4): nums[1] XOR nums[4] = 9
    - (2, 3): nums[2] XOR nums[3] = 6
    - (2, 4): nums[2] XOR nums[4] = 5
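As a reference point, the definition can be translated directly into an O(n²) double loop. This is a minimal sketch for checking small inputs; the helper name `count_nice_pairs_brute` is hypothetical and not part of the original solution. It reproduces the outputs of both examples above.

```python
from typing import List


def count_nice_pairs_brute(nums: List[int], low: int, high: int) -> int:
    """Count pairs (i, j) with i < j and low <= nums[i] ^ nums[j] <= high.

    Hypothetical helper: a direct translation of the problem definition,
    only practical for small inputs.
    """
    n = len(nums)
    count = 0
    for i in range(n):
        for j in range(i + 1, n):  # j > i, so each pair is counted once
            if low <= (nums[i] ^ nums[j]) <= high:
                count += 1
    return count


print(count_nice_pairs_brute([1, 4, 2, 7], 2, 6))      # 6  (Example 1)
print(count_nice_pairs_brute([9, 8, 4, 2, 1], 5, 14))  # 8  (Example 2)
```

At the upper constraint (nums.length = 2 * 10^4) this loop performs about 2 * 10^8 pair checks, which is why the solution below counts pairs with a binary trie instead.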


Constraints:

  • 1 <= nums.length <= 2 * 10^4
  • 1 <= nums[i] <= 2 * 10^4
  • 1 <= low <= high <= 2 * 10^4
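These bounds determine how many bits the trie needs. The sketch below (the names `MAX_VALUE`, `largest_operand` and `bits_needed` are hypothetical) checks that the largest number the solution ever handles, the query limit high + 1 = 20001, fits in 15 bits. The XOR of two values below 2^15 is also below 2^15, so bit indices 14..0 are enough; the solution's MAX_BIT = 15 scans one extra, always-zero bit, which is harmless.

```python
# Largest value nums[i], low or high can take under the constraints.
MAX_VALUE = 2 * 10**4

# The solution queries with limit = high + 1, which can reach MAX_VALUE + 1.
largest_operand = MAX_VALUE + 1

# int.bit_length() returns the number of bits needed to represent the value.
bits_needed = largest_operand.bit_length()

print(bits_needed)       # 15, i.e. bit indices 14..0
print(2 ** bits_needed)  # 32768, strictly greater than 20001
```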

Solution


Python3

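from typing import List


class Solution:
    def countPairs(self, nums: List[int], low: int, high: int) -> int:
        # nums[i], low, high <= 2 * 10^4 and high + 1 <= 20001 < 2^15,
        # so scanning bits 15..0 covers every value (bit 15 is always 0).
        MAX_BIT = 15
        COUNT = 2
        # A trie node is a list [child_for_bit_0, child_for_bit_1, count],
        # where count is how many inserted numbers pass through the node.
        root = [None, None, 0]

        def add(x):
            """Insert x into the binary trie, most significant bit first."""
            curr = root
            for k in range(MAX_BIT, -1, -1):
                bit = (x >> k) & 1
                if curr[bit] is None:
                    curr[bit] = [None, None, 0]
                curr = curr[bit]
                curr[COUNT] += 1

        def query(x, limit):
            """Count the inserted numbers y with (x XOR y) < limit."""
            curr = root
            res = 0
            for k in range(MAX_BIT, -1, -1):
                # No inserted number shares this prefix: nothing more to count.
                if curr is None:
                    break

                x_bit = (x >> k) & 1
                l_bit = (limit >> k) & 1

                if l_bit == 0:
                    # The XOR bit must also be 0 to stay within limit,
                    # so y must have the same bit as x here.
                    curr = curr[x_bit]
                    continue

                # l_bit == 1: every y with the same bit as x makes this XOR
                # bit 0 while limit has 1, so x ^ y < limit for that whole subtree.
                if curr[x_bit] is not None:
                    res += curr[x_bit][COUNT]

                # y with the opposite bit makes the XOR bit 1, equal to limit's
                # bit, so keep descending and compare the lower bits.
                curr = curr[1 - x_bit]

            # Surviving to the last bit means x ^ y == limit, which is not
            # strictly less, so those numbers are not counted.
            return res

        ans = 0
        for x in nums:
            # Pairs with low <= x ^ y <= high
            #   = (pairs with x ^ y < high + 1) - (pairs with x ^ y < low).
            # Querying before inserting x pairs it only with earlier indices (i < j).
            ans += query(x, high + 1) - query(x, low)
            add(x)

        return ans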
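The line `ans += query(x, high + 1) - query(x, low)` relies on an identity for integers: the number of pairs whose XOR lies in [low, high] equals the number with XOR < high + 1 minus the number with XOR < low. The sketch below checks that identity by brute force on Example 2; the helper `count_pairs_below` is hypothetical and runs in O(n²), so it only illustrates the counting, not the trie.

```python
from itertools import combinations
from typing import List


def count_pairs_below(nums: List[int], limit: int) -> int:
    """Hypothetical helper: count pairs i < j with nums[i] ^ nums[j] < limit."""
    # combinations() yields each pair of positions once, in index order.
    return sum(1 for a, b in combinations(nums, 2) if a ^ b < limit)


nums, low, high = [9, 8, 4, 2, 1], 5, 14

below_high = count_pairs_below(nums, high + 1)  # pairs with XOR <= 14
below_low = count_pairs_below(nums, low)        # pairs with XOR <= 4 (XORs 1 and 3)
print(below_high, below_low, below_high - below_low)  # 10 2 8
```

The difference, 8, matches the expected output of Example 2. The trie solution computes the same two counts, but incrementally per element and in O(MAX_BIT) time per query.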