"""LPT-based redistribution of gids across ranks.

Every rank contributes (complexity, gid) pairs; rank 0 gathers them, splits
them into partitions with the LPT heuristic and sends partition i to rank i.
"""
from __future__ import print_function

from util import *
from lpt import lpt, statistics
from all2all import all2all


def load_bal(cx, npart):
    """Return an LPT-balanced list of gids that should belong to this rank.

    Collective: every rank must call this, since it performs two all2all
    exchanges.

    cx -- list of (complexity, gid) pairs currently on this rank.
    npart -- number of partitions rank 0 builds with lpt(); partition i is
             sent to rank i.

    Returns what rank 0 sent to this rank (element [1] of this rank's
    partition from lpt), or an empty list if rank 0 sent nothing here.
    """
    # h and rank are not defined in this module; presumably they come from
    # `from util import *` (h.startsw() used as a wall-clock stopwatch).
    start = h.startsw()

    # Send this rank's pairs to rank 0 (all2all keys are destination ranks).
    gathered = all2all({0: cx})

    outgoing = {}
    if rank == 0:
        # Flatten the per-source lists into one list of (complexity, gid).
        pairs = [pair for chunk in gathered.values() for pair in chunk]
        # Drop the gathered dict before partitioning to keep peak memory down.
        del gathered
        parts = lpt(pairs, npart)
        print(statistics(parts))
        for i, part in enumerate(parts):
            # NOTE(review): part[1] presumably holds partition i's gid list;
            # confirm against lpt().
            outgoing[i] = part[1]
    else:
        del gathered

    # Send partition i to rank i; every rank receives only from rank 0.
    received = all2all(outgoing)
    del outgoing

    if rank == 0:
        print("load_bal time %g" % (h.startsw() - start))

    # A rank that rank 0 sent nothing to (e.g. rank >= npart) owns no gids,
    # instead of failing with KeyError.
    # NOTE(review): assumes all2all omits sources that sent nothing — confirm.
    return received.get(0, [])


if __name__ == '__main__':
    from util import serialize, finish

    # Demo: each rank starts with gids 10*rank+1 .. 10*rank+4, using the gid
    # itself as its complexity.
    cx = [(10 * rank + i, 10 * rank + i) for i in range(1, 5)]
    print(cx)
    cx = load_bal(cx, nhost)
    # Presumably serialize() yields on one rank at a time so output is ordered.
    for _ in serialize():
        print('rank %d ' % rank, cx)
    # NOTE(review): nhost is presumably always >= 1, making this always true.
    if nhost > 0:
        finish()