from util import *
from all2all import all2all
import heapq


def lpt(cx, npart):
    """Distribute complexity items over npart partitions, largest first.

    Items are visited in order of decreasing complexity and each one is
    appended to whichever partition currently has the smallest total
    (greedy longest-processing-time-first assignment).

    cx    -- list of (cx, gid) items.  NOTE: sorted IN PLACE (descending
             by complexity), so the caller's list is reordered.
    npart -- number of partitions to create.

    Returns a list of npart tuples (total_cx, [(cx, gid), ...]) ordered by
    increasing total_cx.  Which partition ends up on which rank is not
    decided here.

    When two partitions have equal totals the heap falls back to comparing
    their item lists, so the (cx, gid) items must be mutually orderable.
    If npart < 1 and cx is not empty, heapq.heappop raises IndexError.
    """
    cx.sort(key=lambda item: item[0], reverse=True)

    # Priority queue for fast lookup of the currently least complex
    # partition.  It always holds exactly npart entries of the form
    # (total_cx, [list of (cx, gid)]).  A list of identical entries already
    # satisfies the heap invariant, so no explicit pushes are needed.
    heap = [(0.0, []) for _ in range(npart)]

    # Each item goes into the current least complex partition.
    for item in cx:
        total, members = heapq.heappop(heap)
        members.append(item)
        heapq.heappush(heap, (total + item[0], members))

    # Drain the heap so the result is ordered by increasing total.
    parts = []
    while heap:
        parts.append(heapq.heappop(heap))
    return parts


def statistics(parts):
    """Return a one-line load balance summary for the output of lpt().

    parts -- list of (total_cx, [(cx, gid), ...]) partitions.

    The returned string reports:
      loadbal     -- average partition total / largest partition total
                     (1.0 when the largest total is not positive)
      total_cx    -- sum of all partition totals
      npart       -- number of partitions
      ncx         -- number of items over all partitions
      max_part_cx -- largest partition total
      max_cx      -- largest single item complexity

    An empty parts list yields a summary of zeros instead of raising
    ZeroDivisionError.
    """
    npart = len(parts)
    total_cx = 0
    max_part_cx = 0
    ncx = 0
    max_cx = 0
    for part_total, items in parts:
        ncx += len(items)
        total_cx += part_total
        max_part_cx = max(max_part_cx, part_total)
        for item in items:
            max_cx = max(max_cx, item[0])

    # float() forces true division under Python 2 when totals are ints;
    # the npart guard avoids dividing by zero when there are no partitions.
    avg_part_cx = float(total_cx) / npart if npart else 0.0

    loadbal = 1.0
    if max_part_cx > 0.:
        loadbal = avg_part_cx / max_part_cx

    s = "loadbal=%g total_cx=%g npart=%d ncx=%d max_part_cx=%g max_cx=%g"%(loadbal,total_cx,npart,ncx,max_part_cx, max_cx)
    return s


if __name__ == '__main__':
    from util import serialize, finish

    # Small self-test: ten items of increasing complexity, then an empty
    # list, each split into three partitions.  Every print call receives a
    # single pre-formatted string so the output is identical whether print
    # is a statement (Python 2) or a function (Python 3).
    for cx in ([(i, i) for i in range(10)], []):
        print("%s %s %s" % (len(cx), ' complexity items ', cx))
        pinfo = lpt(cx, 3)
        print("%s %s %s" % (len(pinfo), ' lpt partitions ', pinfo))
        print(statistics(pinfo))