|
- #
- # ---------------------------------------------
- # Bitonic v0.5 functionality
- #
-
"""
    exchange(localid, remoteid)

Placeholder for transferring data between node `localid` and its partner
`remoteid`. In this single-process simulation every row already lives in the
same array, so there is nothing to send and the call simply returns `nothing`.
"""
exchange(localid, remoteid) = nothing
-
"""
    minmax(data, localid, remoteid, keepsmall)

Compare-exchange rows `localid + 1` and `remoteid + 1` of `data`, column by
column and in place. When `keepsmall` is `true` the local row receives the
smaller element of each column pair and the remote row the larger one;
otherwise the roles are swapped. Returns `nothing`.

Elements are ordered with `isless` (the same ordering `sort!` and `issorted`
use), so the two rows always end up holding exactly the elements they started
with. Plain `min`/`max` would copy a `NaN` into both rows and drop its partner.

NOTE(review): this top-level name shadows `Base.minmax` in the enclosing module.
"""
function minmax(data, localid, remoteid, keepsmall)
    lrow = localid + 1   # node ids are 0-based, rows are 1-based
    rrow = remoteid + 1
    for j in axes(data, 2)
        a, b = data[lrow, j], data[rrow, j]
        small, large = isless(b, a) ? (b, a) : (a, b)
        if keepsmall
            data[lrow, j], data[rrow, j] = small, large
        else
            data[lrow, j], data[rrow, j] = large, small
        end
    end
    return nothing
end
-
-
"""
    is_bitonic(arr)

Return `true` if `arr` is bitonic, i.e. some cyclic rotation of it first rises
and then falls (either part may be empty). Equal neighbours never start a new
run, and every sequence of length <= 2 counts as bitonic.
"""
function is_bitonic(arr)
    len = length(arr)
    len <= 2 && return true  # Any sequence of length <= 2 is bitonic

    dir = 0     # direction of the current run: 1 rising, -1 falling, 0 not yet known
    rises = 0   # number of rising runs entered
    falls = 0   # number of falling runs entered
    for k in 1:len-1
        a, b = arr[k], arr[k+1]
        if dir != 1 && a < b
            dir, rises = 1, rises + 1
        elseif dir != -1 && a > b
            dir, falls = -1, falls + 1
        end
    end

    if rises <= 1 && falls <= 1
        return true
    elseif rises == 2 && falls == 1
        # rise-fall-rise: the wrap-around step arr[len] -> arr[1] must not fall,
        # so the last run joins the first one
        return arr[1] >= arr[len]
    elseif rises == 1 && falls == 2
        # fall-rise-fall: the wrap-around step must not rise
        return arr[1] <= arr[len]
    end
    return false
end
-
"""
    is_sort(arr)

Classify the ordering of `arr`:

- `1`  if `arr` is non-decreasing,
- `-1` if it is non-increasing with at least one strict decrease,
- `0`  if it changes direction (not monotonic).

Sequences with no strict step at all (empty, a single element, or all elements
equal) are trivially sorted and reported as `1`, consistent with `issorted`.
Neighbours that compare neither `<` nor `>` (equal, or `NaN`) are treated as ties.
"""
function is_sort(arr)
    dir = 0  # direction of the first strict step: 1 rising, -1 falling, 0 none yet
    for i in firstindex(arr):lastindex(arr)-1
        a, b = arr[i], arr[i+1]
        step = a < b ? 1 : (a > b ? -1 : 0)
        step == 0 && continue        # ties never change the direction
        if dir == 0
            dir = step
        elseif step != dir
            return 0                 # a second run started: not monotonic
        end
    end
    return dir == 0 ? 1 : dir
end
-
"""
    sort_network!(data, n, depth)

Run the inter-node compare-exchange steps of bitonic stage `depth` on `data`,
in place.

Row `i + 1` of `data` holds the data of simulated node `i` (`0 <= i < n`). For
each `step = depth-1, ..., 0`, node `i` is paired with `i ⊻ 2^step`. The
lower-numbered node of each pair calls `exchange` and then `minmax`. Whether it
keeps the small or the large elements depends on whether bit `depth` of its id
is clear (ascending block) or set (descending block).

When the global `verbose` is set, per-step diagnostics are printed, including
`is_bitonic`/`is_sort` flags for every row of `data`.

NOTE(review): presumably expects `n` to be a power of two with `size(data, 1) >= n`;
otherwise partner ids can fall outside the rows of `data`.
"""
function sort_network!(data, n, depth)
    nodes = 0:n-1
    for step in depth-1:-1:0
        partnerid = nodes .⊻ (1 << step)
        # `true` where the node's block is sorted ascending at this stage.
        # The parentheses matter: `&` binds tighter than `==` in Julia.
        direction = (nodes .& (1 << depth)) .== 0
        # The lower node of an ascending pair, or the higher node of a
        # descending pair, keeps the small elements.
        keepsmall = ((nodes .< partnerid) .& direction) .|
                    ((nodes .> partnerid) .& .!direction)
        if verbose
            println("depth: $depth | step: $step | partner: $partnerid | keepsmall: $keepsmall")
        end
        # exchange with partner and keep small or large (run all MPI nodes);
        # only the lower id of each pair performs the compare-exchange
        for i in nodes
            partner = partnerid[i+1]
            if i < partner
                exchange(i, partner)
                minmax(data, i, partner, keepsmall[i+1])
            end
        end
        if verbose
            # diagnostic flags are only built when they are going to be printed
            bitonicFlag = Int8[is_bitonic(row) for row in eachrow(data)]
            sortFlag = Int8[is_sort(row) for row in eachrow(data)]
            println("depth: $depth | step: $step | bitonicFlag: $bitonicFlag | sortFlag: $sortFlag")
        end
    end
    return nothing
end
-
-
-
"""
    distbitonic!(p, data)

Sort `data` in place with a simulated distributed bitonic sort (v1).

Each row of `data` plays the role of one "process". Every row is first sorted
locally, even rows ascending and odd rows descending. Then, for each stage
`depth = 1:log2(p)`, the inter-row compare-exchange network (`sort_network!`)
runs, followed by a local `sort!` of every row in the direction required by
that stage. The `sort!` stands in for an elbow merge.

On return every row is ascending and, read row by row, the rows form a single
ascending sequence.

# Arguments
- `p`: number of simulated processes; must be a power of two equal to `size(data, 1)`.
- `data`: `(p, N/p)` matrix, modified in place.

# Throws
- `ArgumentError` if `p` is not a positive power of two or does not match `size(data, 1)`.
"""
function distbitonic!(p, data)
    ispow2(p) || throw(ArgumentError("p must be a positive power of two, got $p"))
    size(data, 1) == p ||
        throw(ArgumentError("data must have p = $p rows, got $(size(data, 1))"))

    q = Int(log2(p)) # CPU order
    pid = 0:p-1
    # Stage 0 is only the initial local sort (bit 0 of the id clear -> ascending).
    # Each later stage first runs the inter-row network, then re-sorts every row
    # (local elbow merge stand-in, run on all MPI nodes).
    for depth in 0:q
        depth > 0 && sort_network!(data, p, depth)
        ascending = (pid .& (1 << depth)) .== 0
        if verbose
            println("ascending: $ascending")
        end
        for i in 1:p
            sort!(view(data, i, :), rev = !ascending[i])
        end
    end
    return nothing
end
-
#
# Homework setup
# ---------------------------------------------
#
p::Int8 = 3            # The order of number of "processors"
q::Int8 = 8            # The data size order (power of 2) of each "processor"
verbose::Bool = false  # Print per-step diagnostics from the sorting functions


# Run Script
# ---------------------------------------------
P::Int = 2^p       # number of simulated processors
Q::Int = 2^q       # elements per processor
N::Int = 2^(q+p)   # total number of elements (P * Q)

println("Distributed bitonic (v1) test")
println("p: $p -> Number of processors: $P")
println("q: $q -> Data length for each node: $Q, Total: $N")

println("Create an $P x $Q array")
Data = rand(Int8, P, Q)
# Reference result: the same elements, sorted, in row-by-row order.
Expected = sort(vec(permutedims(Data)))

println("Sort array with $P (MPI) nodes")
@time distbitonic!(P, Data)

# Test: reading the rows in order must give an ascending sequence made of
# exactly the input elements. Checking only `issorted` would miss a
# compare-exchange that duplicates or drops values.
if vec(permutedims(Data)) == Expected
    println("Test: Passed")
else
    println("Test: Failed")
end
|