#
# ---------------------------------------------
# Bitonic v0.5 functionality
#

function exchange(localid,  remoteid)
    if verbose
        println("Exchange local data from $localid with partner $remoteid")
    end
    nothing  # no-op: in this simulation all "node" rows live in the same local array
end
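
# In a real MPI run each row would live on its own rank and `exchange` would do the
# actual communication. A minimal sketch of that step, assuming MPI.jl and its
# keyword-argument form of `MPI.Sendrecv!` (the exact signature depends on the
# MPI.jl version), is left commented out so this simulation script stays self-contained:
#
# using MPI
# function exchange_mpi!(localrow, remoterow, remoteid, comm)
#     # send our row to the partner and receive the partner's row in one call
#     MPI.Sendrecv!(localrow, remoterow, comm; dest = remoteid, source = remoteid)
#     return remoterow
# end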

function minmax(data, localid, remoteid, keepsmall)
    # Compare-exchange between the local and remote rows; copy the local row first
    # so the second assignment below still sees its original values
    temp = copy(data[localid+1, :])
    if (keepsmall)
        view(data, localid+1, :)  .= min.(temp, data[remoteid+1, :])
        view(data, remoteid+1, :) .= max.(temp, data[remoteid+1, :])
    else
        view(data, localid+1, :)  .= max.(temp, data[remoteid+1, :])
        view(data, remoteid+1, :) .= min.(temp, data[remoteid+1, :])
    end
end
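
# Hypothetical usage of `minmax` on two tiny "node" rows (values chosen only for illustration):
#
#   data = Int8[3 1 4 1;
#               5 9 2 6]
#   minmax(data, 0, 1, true)   # node 0 keeps the element-wise minima
#   # data is now [3 1 2 1; 5 9 4 6]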



"""
    sort_network!(data, n, depth)

One exchange stage of the bitonic network: for `step = depth-1` down to `0`, every
node compare-exchanges its row with partner `node ⊻ 2^step`, keeping the small or
the large elements depending on the merge direction of this stage.
"""
function sort_network!(data, n, depth)
    nodes = 0:n-1
    for step = depth-1:-1:0
        partnerid = nodes .⊻ (1 << step)
        direction = (nodes .& (1 << depth)) .== 0   # ascending when bit `depth` of the node id is 0
        keepsmall = ((nodes .< partnerid) .& direction) .| ((nodes .> partnerid) .& .!direction)
        if verbose
            println("depth: $depth | step: $step | partner: $partnerid | keepsmall: $keepsmall")
        end
        # exchange with partner and keep the small or the large half (runs on all MPI nodes)
        for i in 0:n-1
            if (i < partnerid[i+1])
                exchange(i, partnerid[i+1])
                minmax(data, i, partnerid[i+1], keepsmall[i+1])
            end
        end
    end
end
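
# Example of the exchange pattern this produces for n = 4 nodes at depth = 2
# (the final, fully ascending merge stage), worked out by hand:
#
#   step 1: partner = [2, 3, 0, 1], keepsmall = [1, 1, 0, 0]
#   step 0: partner = [1, 0, 3, 2], keepsmall = [1, 0, 1, 0]
#
# i.e. nodes 0-1 keep the small halves against nodes 2-3, then neighbours are
# compare-exchanged pairwise: the classic bitonic merge network.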



"""
    distbitonic!(p, data)

Distributed bitonic sort (v1). Each "node" owns one row of `data`; after every
exchange stage the row is re-sorted locally (a full `sort!` is used here as a
stand-in for the local elbow merge after the first stage).

p:    the number of processes (must be a power of 2)
data: (p, N/p) array, one row per process
"""
function distbitonic!(p, data) 

    q = Int(log2(p))    # number of network stages; requires p to be a power of 2
  
    pid = 0:p-1
    ascending = mod.(pid,2) .== 0
    if verbose
        println("ascending: $ascending")
    end
    # local full sort here (runs on all MPI nodes)
    for i in 1:p
        sort!(view(data, i, :), rev = !ascending[i])
    end
    for depth = 1:q
        sort_network!(data, p, depth)
        ascending = (pid .& (1 << depth)) .== 0
        if verbose
            println("ascending: $ascending")
        end
        # local elbow merge here (full sort! as a stand-in; runs on all MPI nodes)
        for i in 1:p
            sort!(view(data, i, :), rev = !ascending[i])
        end
    end
  
    nothing
end
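
# The docstring above mentions an elbow merge, but the loop falls back to a full
# `sort!`. A minimal sketch of the O(N/p) local elbow merge it alludes to, assuming
# each row is (cyclically) bitonic after an exchange stage, could look like the
# function below (illustrative only; not wired into `distbitonic!`):
function elbowmerge!(dst, v; rev = false)
    n = length(v)
    e = argmin(v)                  # the "elbow": position of the minimum
    l = mod1(e - 1, n)             # pointer walking left, with wraparound
    r = mod1(e + 1, n)             # pointer walking right, with wraparound
    dst[1] = v[e]
    for k in 2:n
        if l == r                  # one element left: the maximum
            dst[k] = v[l]
        elseif v[l] <= v[r]        # take the smaller head of the two sorted runs
            dst[k] = v[l]
            l = mod1(l - 1, n)
        else
            dst[k] = v[r]
            r = mod1(r + 1, n)
        end
    end
    rev && reverse!(dst)
    return dst
end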

#
# Homework setup
# ---------------------------------------------
#
p::Int8 = 3  # log2 of the number of "processors"
q::Int8 = 8  # log2 of the data length per "processor"
verbose = false;


# Run Script
# ---------------------------------------------
P::Int = 2^p
Q::Int = 2^q
N::Int = 2^(q+p)

println("Distributed bitonic (v1) test")
println("p: $p -> Number of processors: $P")
println("q: $q -> Data length for each node: $Q, Total: $(P*Q)")

println("Create an $P x $Q array")
Data = rand(Int8, P, Q)

println("Sort array with $P (MPI) nodes")
@time distbitonic!(P, Data)

# Test: the rows concatenated in node order must be globally sorted
if issorted(vec(permutedims(Data)))
    println("Test: Passed")
else
    println("Test: Failed")
end