diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java index 81b3bb7b38d..1b74a3ae56f 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java @@ -22,6 +22,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.commons.math3.random.Well1024a; import org.apache.sysds.common.Opcodes; import org.apache.sysds.common.Types; import org.apache.sysds.hops.DataGenOp; @@ -177,6 +178,7 @@ else if(method == Types.OpOpDG.SEQ) { @Override public void processInstruction(ExecutionContext ec) { final OOCStream qOut = createWritableStream(); + ec.getMatrixObject(output).setStreamHandle(qOut); // process specific datagen operator if (method == Types.OpOpDG.RAND) { @@ -188,9 +190,6 @@ public void processInstruction(ExecutionContext ec) { long lcols = ec.getScalarInput(cols).getLongValue(); checkValidDimensions(lrows, lcols); - if (!pdf.equalsIgnoreCase("uniform") || minValue != maxValue) - throw new NotImplementedException(); // TODO modified version of rng as in LibMatrixDatagen to handle blocks independently - OOCStream qIn = createWritableStream(); int nrb = (int)((lrows-1) / blen)+1; int ncb = (int)((lcols-1) / blen)+1; @@ -210,10 +209,37 @@ public void processInstruction(ExecutionContext ec) { return; } + if(sparsity == 1.0 && minValue == maxValue) { + mapOOC(qIn, qOut, idx -> { + long rlen = Math.min(blen, lrows - (idx.getRowIndex()-1) * blen); + long clen = Math.min(blen, lcols - (idx.getColumnIndex()-1) * blen); + return new IndexedMatrixValue(idx, new MatrixBlock((int)rlen, (int)clen, minValue)); + }); + return; + } + + Well1024a bigrand = LibMatrixDatagen.setupSeedsForRand(lSeed); + int nb = nrb * ncb; + long[] seeds = new long[nb]; + for(int i = 0; i < nb; i++) seeds[i] = bigrand.nextLong(); + mapOOC(qIn, qOut, idx -> { long rlen = Math.min(blen, lrows - (idx.getRowIndex()-1) * blen); long clen = Math.min(blen, lcols - (idx.getColumnIndex()-1) * blen); - MatrixBlock mout = MatrixBlock.randOperations(getGenerator(rlen, clen), lSeed); + + int r = (int) idx.getRowIndex()-1; + int c = (int) idx.getColumnIndex()-1; + long bSeed = seeds[r*ncb+c]; + + final long estnnz = ((minValue==0.0 && maxValue==0.0) ? 0 : (long)(sparsity * rlen * clen)); + boolean lsparse = MatrixBlock.evalSparseFormatInMemory(rlen, clen, estnnz); + + MatrixBlock mout = new MatrixBlock(); + mout.reset((int) rlen, (int) clen, lsparse, estnnz); + mout.allocateBlock(); + + LibMatrixDatagen.genRandomNumbers(false, 0, 1, 0, 1, mout, getGenerator(rlen, clen), bSeed, null); + mout.recomputeNonZeros(); return new IndexedMatrixValue(idx, mout); }); } @@ -263,8 +289,6 @@ else if(method == Types.OpOpDG.SEQ) { } else throw new NotImplementedException(); - - ec.getMatrixObject(output).setStreamHandle(qOut); } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java index 58fb8e4fa8d..1edd439e40a 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java @@ -460,7 +460,7 @@ private static long[] sliceSeedsForCP(long[] seeds, int rl, int ru, int cl, int return lseeds; } - private static void genRandomNumbers(boolean invokedFromCP, int rl, int ru, int cl, int cu, MatrixBlock out, RandomMatrixGenerator rgen, long bSeed, long[] seeds) { + public static void genRandomNumbers(boolean invokedFromCP, int rl, int ru, int cl, int cu, MatrixBlock out, RandomMatrixGenerator rgen, long bSeed, long[] seeds) { int rows = rgen._rows; int cols = rgen._cols; int blen = rgen._blocksize; diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/RandTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/RandTest.java index 40430aa49f9..398e8d105bb 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/RandTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/RandTest.java @@ -53,11 +53,10 @@ public void setUp() { addTestConfiguration(TEST_NAME_2, config2); } - // Actual rand operation not yet supported - /*@Test + @Test public void testRand() { runRandTest(TEST_NAME_1); - }*/ + } @Test public void testConstInit() { diff --git a/src/test/scripts/functions/ooc/Rand1.dml b/src/test/scripts/functions/ooc/Rand1.dml index 2861f294620..6b08a84d0e3 100644 --- a/src/test/scripts/functions/ooc/Rand1.dml +++ b/src/test/scripts/functions/ooc/Rand1.dml @@ -19,6 +19,6 @@ # #------------------------------------------------------------- -res = rand(rows=1500, cols=1200, min=-1, max=1); +res = rand(rows=1500, cols=1200, min=-1, max=1, seed=42); write(res, $2, format="binary");