ALGORITHM September 03, 2021

Tìm trung vị của 2 dãy đã được sắp xếp

Số chữ 16k Thời gian đọc 14 mins.

Cho hai dãy nums1, nums2 đã được sắp xếp theo thứ tự có size là mn, hãy tìm trung vị (median) của hai dãy đó.

Đây là bài toán được xem là mức độ khó trên leetcode.com, thách thức của bài này là thuật toán phải được chạy với time O(log(m+n))


Nguyên văn bài toán và các ví dụ

Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays.

The overall run time complexity should be O(log (m+n)).

Example 1:
Input: nums1 = [1,3], nums2 = [2]
Output: 2.00000
Explanation: merged array = [1,2,3] and median is 2.

Example 2:
Input: nums1 = [1,2], nums2 = [3,4]
Output: 2.50000
Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5.

Example 3:
Input: nums1 = [0,0], nums2 = [0,0]
Output: 0.00000

Example 4:
Input: nums1 = [], nums2 = [1]
Output: 1.00000

Example 5:
Input: nums1 = [2], nums2 = []
Output: 2.00000

Trung vị là gì?

Khác với trung bình, trung vị là vị trí phần tử ở chính giữa của một dãy đã sắp xếp, vị trí đó cho biết giá trị điển hình của dãy đó. Nếu dãy có số lượng phần tử là lẻ, thì sẽ lấy phần tử ở giữa, ngược lại là chẵn, thì sẽ lấy trung bình của 2 phần tử ở giữa

Hướng tiếp cận 1

Hướng tiếp cận ban đầu, dễ nhất chính là sử dụng một dãy merge chứa toàn bộ phần tử của hai dãy, sau đó cài đặt thuật toán sắp xếp lại rồi tìm trung vị như công thức.

def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
\tnums = nums1 + nums2
    #sort in-place memory
    nums.sort() 
    length = len(nums)
        
\treturn (nums[int(length/2)-1] + nums[int(length/2)])/2 if length % 2 == 0 else nums[int(length/2)]

Lợi ích của code này là bạn không cần quan tâm đến việc sort() của code được chạy như thế nào, thực tế, python sử dụng thuật toán Timsort, giải quyết được tổng quát đa số trường hợp để có time and space complexity thấp nhất, bạn có thể tham khảo thêm tại:

https://en.wikipedia.org/wiki/Timsort

Timsort is a hybrid sorting algorithm, derived from merge sort and insertion sort, designed to perform well on many kinds of real-world data. It was invented by Tim Peters in 2002 for use in the Python programming language. The algorithm finds subsets of the data that are already ordered, and uses the subsets to sort the data more efficiently. This is done by merging an identified subset, called a run, with existing runs until certain criteria are fulfilled. Timsort has been Python’s standard sorting algorithm since version 2.3. It is now also used to sort arrays in Java SE 7, and on the Android platform.

Tuy vậy, việc bạn sử dụng sort() lại không thực sự tối ưu, cũng bởi vì code quá ngắn, tinh túy của bài toán bạn chưa thực sự chạm vào, thì không một interviewer nào chấp nhận code này làm câu trả lời cuối cùng

Ở đây chúng tôi không làm thế ;)

Hướng tiếp cận 2

Khi bạn được reject câu trả lời trên, thì bạn sẽ bảo “Ồ, vậy tôi thử tìm một giải thuật sắp xếp nào đó để được O(log(m+n)) xem sao”. Nếu bạn được nhớ lại, hoặc may mắn được interviewer nhắc bài “Cho tới hiện nay, không một chương trình tuyến tính nào có thể cài đặt thuật toán sort tốt hơn O((n+m)log(n+m))“. Nói như vậy không có nghĩa bạn lại đi cài đặt parallel nhé :)

Đúc kết lại, nếu bạn cố tìm thuật toán merge 2 dãy đã sắp xếp lại, thì nó không phải là câu trả lời interviewer muốn nghe

Nếu bạn có hứng thú với thuật toán merge này, bạn có thể tham khảo tại: https://www.geeksforgeeks.org/merge-two-sorted-arrays-o1-extra-space/

Hướng tiếp cận 3

Xem lại bài toán, xét 2 khía cạnh của đề bài:

  1. Hai dãy đã được sắp xếp
  2. Chỉ quan tâm số ở giữa để tìm trung vị

=> Cần thiết để phải sort lại toàn bộ khi mình chỉ cần 1, 2 phần tử ở giữa?

Ý tưởng chính

Xét 2 dãy A = [3, 6, 8, 9, 10]B = [1, 4, 12].

Dãy merge đã sắp xếp sẽ là C = [1, 3, 4, 6, 8, 9, 10, 12] => Dãy Median(C) = (6 + 8)/2 = 7

Nếu tôi cắt đôi dãy A và dãy A, 2 dãy bên trái gọi là nhóm X, 2 dãy bên phải gọi là nhóm Y, sao cho mỗi nhóm có số lượng bằng nhau hoặc chênh lệch một phần tử như sau

Nhóm X Nhóm Y
[3, 6, 8] [9, 10]
[1] [4, 12]

Nếu nhóm X hoàn toàn thuộc về bên trái, nhóm Y hoàn toàn thuộc về bên phải, hay nói cách khác, Max(X) <= Min(Y), thì bạn đã ở vị trí chính giữa của 2 dãy, hoàn toàn có thể tìm được giá trị trung vị mà không cần phải sắp xếp 2 dãy đó. Việc của mình chỉ là dịch (shift) trái hoặc dịch phải các phần tử để đáp ứng tiêu chí đó. Nếu nhóm X cho nhóm Y một phần tử ở dãy A, thì nhóm Y cũng phải cho X một phần tử ở dãy B tương ứng. Nếu bạn nhìn kỹ thì nó như theo một chiều của kim đồng hồ.

Tôi thử shift chiều kim đồng hồ, kết quả:

Nhóm X Nhóm Y
[3, 6] [8, 9, 10]
[1, 4] [12]

Nhận xét Max(X) = 6 <= Min(Y) = 8. Do tổng số phần tử là chẵn, nên (6+8)/2 = 7 => giá trị cần tìm

Có 2 ràng buộc để solution này khả thi:

  • Nếu dãy A hoặc B không được sắp xếp độc lập, bạn không thể shift
  • Bài toán chỉ để tìm trung vị

Đây là core của bài toán, kỹ thuật này còn gọi là Binary Search - tìm kiếm nhị phân, các mục tiếp theo là đi sâu vào nhiều khía cạnh phải xem xét.

Không còn khả năng shift

Nếu tôi thay 12 bằng 5, bài toán cần thiết phải shift thêm 1 lần, như sau:

Nhóm X Nhóm Y
[3, 6] [8, 9, 10]
[1, 4] [5]
Nhóm X Nhóm Y
[3] [6, 8, 9, 10]
[1, 4, 5] ?

=> [1, 3, 4, 5, 6, 8, 9, 10] => Chọn (5+6)/2 = 5.5
=> Uhmm, không quan trọng lắm, ? thay bằng Infinity thì sẽ không bị gì, tương tự cho các trường hợp còn lại, thì bên Y bị rút hết thì sẽ là dương vô cực, X thì là âm vô cực.

Trường hợp lệch nhau 1 đơn vị

Giả sử tôi thêm 11 vào nhóm Y của dãy A

Nhóm X Nhóm Y
[3] [6, 8, 9, 10, 11]
[1, 4, 5] ?

=> [1, 3, 4, 5, 6, 8, 9, 10, 11] => Chọn 6

=> Nếu lệch ở đâu thì lấy phần tử gần trung tâm nhất của bên đó

Có tồn tại lệch nhau nhiều hơn 1 đơn vị mà không thể shift?

Không. Giả sử:

Nhóm X Nhóm Y
[3] [6, 8, 9, 10, 13, 15]
[1, 4, 5] ?

=> Vi phạm việc chia đôi, chính xác ban đầu phải là:

Nhóm X Nhóm Y
[3, 6] [8, 9, 10, 13, 15]
[1, 4, 5] ?

Làm sao để biết shift chiều nào?

Nhóm X Nhóm Y
[3, 6 = O] [8 = P, 9, 10]
[1, 4 = N] [5 = M]

Rất dễ nhận biết, xét 4 số trung tâm 6, 4, 8, 5, tương ứng O, N, P, M

  • Theo chiều kim nếu O > M: 6 > 5. Mà N <= M, O <= P => P >= O > M >= N => P > N
  • Ngược chiều kim nếu N > P. Mà N <= M, O <= P => M >= N > P >= O => M > O

Như vậy, chỉ tồn tại một trong 2 điều kiện đó có thể xảy ra. Đối với dấu =, thuật toán có thể kết thúc ngay, vd:

Nhóm X Nhóm Y
[3, 5 = O] [5 = P, 9, 10]
[1, 5 = N] [5 = M]

=> Max(X) = 5 <= Minx(Y) = 5 => (5+5)/2 = 5

=> [1, 3, 5, 5, 5, 5, 9, 10] => Chọn 5

Dịch bao nhiêu đơn vị?

Để có được độ phức tạp O(log(m+n)), cần thiết phải search chia đôi liên tục thay vì dịch từng đơn vị 1. Nghĩa là, với độ dài k = m + n, mỗi vòng lặp sẽ thực hiện trên khoảng k = k/2, chọn nửa k bên nào sẽ do chiều shift quyết định.

Cài đặt trên Python

Thuật toán:

  • Chọn A là dãy bé, B là dãy lớn
  • Khai báo các var
  • len_a, len_b là size của dãy A, B
  • most_min, most_max để sử dụng khi không thể shift
  • size_x là size của nhóm X, mọi thao tác đều trên nhóm X
  • a_search_start a_search_end định nghĩa khoảng trên dãy A mà BinarySearch sẽ nhảy vào. Mặc định khởi tạo sẽ bằng với khoảng của dãy A
  • Trong vòng lặp,
  • x_len_a, x_len_b là độ dài dãy A, B trên nhóm X. Lần lặp đầu tiên x_len_a sẽ lấy khoảng nửa của dãy A, từ đó tìm ra x_len_b. Từ lần 2 trở đi, tùy vào chiều shift mà a_search_start/end quyết định x_len_a sẽ tăng hay giảm, dẫn đến x_len_b giảm hay tăng, tương ứng.
  • x_a, y_a, x_b, y_b là 4 vị trí trung tâm đang xét
  • Nếu Min(X) <= Min(Y), tương ứng với x_a <= y_b and x_b <= y_a, xét chẵn lẻ của tổng dãy và ra kết quả
  • Ngược lại có 2 trường hợp xảy ra:
  • Nếu dịch theo chiều kim, x_a > y_b, thì sẽ lùi a_search_end một khoảng sao cho mốc mới sẽ bằng với x_len_a trừ một đơn vị
  • Ngược chiều kim, thì sẽ tăng a_search_start một khoảng sao cho mốc mới sẽ bằng với x_len_a thêm một đơn vị

from typing import List

def findMedianSortedArrays(self, A: List[int], B: List[int]) -> float:
    if len(A) > len(B):
        A, B = B, A

    len_a, len_b = len(A), len(B)
    most_min, most_max = float("-inf"), float("inf")

    size_x = (len_a + len_b + 1) // 2
    a_search_start = 0
    a_search_end = len_a

    is_even = ((len_a + len_b) % 2) == 0
    while True:     
        x_len_a = (a_search_start + a_search_end) // 2
        x_len_b = size_x - x_len_a

        x_a = most_min if x_len_a == 0 else A[x_len_a - 1]
        y_a = most_max if x_len_a == len_a else A[x_len_a]
        x_b = most_min if x_len_b == 0 else B[x_len_b - 1]
        y_b = most_max if x_len_b == len_b else B[x_len_b]

        if x_a <= y_b and x_b <= y_a:
            if is_even:
                return (max(x_a, x_b) + min(y_a, y_b))/ 2

            return max(x_a, x_b)

        if x_a > y_b:
            a_search_end = x_len_a - 1
            continue

        if x_b > y_a:
            a_search_start = x_len_a + 1
            continue

    return 0

Bài viết trước:
Bài viết kế:
0%