diff --git a/docs/site/dml-language-reference.md b/docs/site/dml-language-reference.md index 264b3c6a2b1..6a26c4caaac 100644 --- a/docs/site/dml-language-reference.md +++ b/docs/site/dml-language-reference.md @@ -2068,10 +2068,12 @@ The following example uses transformapply() with the input matrix a **Table F5**: Frame processing built-in functions -Function | Description | Parameters | Example --------- | ----------- | ---------- | ------- -map() | It will execute the given lambda expression on a frame (cell, row or column wise). | Input: (X <frame>, y <String>, \[margin <int>\])
Output: <frame>.
X is a frame and
y is a String containing the lambda expression to be executed on frame X.
margin - how to apply the lambda expression (0 indicates each cell, 1 - rows, 2 - columns). Output matrix dimensions are always equal to the input. | [map](#map) -tokenize() | Transforms a frame to tokenized frame using specification. Tokenization is valid only for string columns. | Input:
target = <frame>
spec = <json specification>
Outputs: <matrix>, <frame> | [tokenize](#tokenize) +Function | Description | Parameters | Example +-------- |-----------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- +map() | It will execute the given lambda expression on a frame (cell, row or column wise). | Input: (X <frame>, y <String>, \[margin <int>\])
Output: <frame>.
X is a frame and
y is a String containing the lambda expression to be executed on frame X.
margin - how to apply the lambda expression (0 indicates each cell, 1 - rows, 2 - columns). Output matrix dimensions are always equal to the input. | [map](#map) +tokenize() | Transforms a frame to tokenized frame using specification. Tokenization is valid only for string columns. | Input:
target = <frame>
spec = <json specification>
Outputs: <matrix>, <frame> | [tokenize](#tokenize) +getNames() | Returns the column names of a frame as a single-row frame. | Input: X <frame>
Output: <frame> | N = getNames(X) +setNames() | Sets the column names of a frame from a single-row frame containing string values. | Input:
X = <frame>
N = <frame>
Output:<frame> | Y = setNames(X, N) #### map diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index f5719641df7..fa5da4cd91d 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -154,6 +154,7 @@ public enum Builtins { GARCH("garch", true), GAUSSIAN_CLASSIFIER("gaussianClassifier", true), GET_ACCURACY("getAccuracy", true), + GET_NAMES("getNames", false), GET_CATEGORICAL_MASK("getCategoricalMask", false), GLM("glm", true), GLM_PREDICT("glmPredict", true), @@ -310,6 +311,7 @@ public enum Builtins { SELVARTHRESH("selectByVarThresh", true), SEQ("seq", false), SES("ses", true), + SET_NAMES("setNames", false), SYMMETRICDIFFERENCE("symmetricDifference", true), SHAPEXPLAINER("shapExplainer", true), SHERLOCK("sherlock", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 9a894dde13b..0f971b764c8 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -352,6 +352,7 @@ public enum Opcodes { MAPPM("map+*", InstructionType.Binary), MAPMINUSMULT("map-*", InstructionType.Binary), MAPDROPINVALIDLENGTH("mapdropInvalidLength", InstructionType.Binary), + SET_COLNAMES("set_colnames", InstructionType.Binary), MAPGT("map>", InstructionType.Binary), MAPGE("map>=", InstructionType.Binary), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index c2832aeb8cd..e7282cde00e 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -641,7 +641,8 @@ public enum OpOp2 { MINUS1_MULT(false), //1-X*Y GET_CATEGORICAL_MASK(false), // get transformation mask QUANTIZE_COMPRESS(false), //quantization-fused compression - UNION_DISTINCT(false); + UNION_DISTINCT(false), + SET_COLNAMES(false); private 
final boolean _validOuter;
 
diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index ab0c7993b4e..301b4d2765a 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1095,6 +1095,7 @@ else if( getAllExpr().length == 2 ) { //binary
 		case TYPEOF:
 		case DETECTSCHEMA:
 		case COLNAMES:
+		case GET_NAMES:
 			checkNumParameters(1);
 			checkMatrixFrameParam(getFirstExpr());
 			output.setDataType(DataType.FRAME);
@@ -1102,6 +1103,18 @@ else if( getAllExpr().length == 2 ) { //binary
 			output.setBlocksize (id.getBlocksize());
 			output.setValueType(ValueType.STRING);
 			break;
+		case SET_NAMES:
+			// setNames(X, N) takes exactly two arguments: the target X and the new names N
+			checkNumParameters(2);
+			// both arguments have to pass the matrix/frame parameter check
+			checkMatrixFrameParam(getFirstExpr());
+			checkMatrixFrameParam(getSecondExpr());
+			// the output is a frame that takes dimensions and blocksize from id
+			output.setDataType(DataType.FRAME);
+			output.setDimensions(id.getDim1(), id.getDim2());
+			output.setBlocksize (id.getBlocksize());
+			output.setValueType(ValueType.STRING);
+			break;
 		case CAST_AS_FRAME:
 			// operation as.frame
 			// overloaded to take either one argument or 2 where second is column names
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index e14cfd31388..0000cc20677 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2762,6 +2762,20 @@ else if ( in.length == 2 )
 		case TYPEOF:
 		case DET:
 		case DETECTSCHEMA:
+			// the labels above keep their generic unary mapping and must not fall
+			// through into the SET_NAMES case below
+			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
+				OpOp1.valueOf(source.getOpCode().name()), expr);
+			break;
+		case SET_NAMES:
+			currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
+				OpOp2.SET_COLNAMES, expr, expr2);
+			break;
+		case GET_NAMES:
+			// getNames(X) is mapped onto the existing COLNAMES unary operator
+			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
+				OpOp1.COLNAMES, expr);
+			break;
 		case COLNAMES:
 			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
 				OpOp1.valueOf(source.getOpCode().name()), expr);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index da3de02419d..031cf406d8d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -686,6 +686,8 @@ else if( opcode.equalsIgnoreCase(Opcodes.VALUESWAP.toString()))
 			return new BinaryOperator(Builtin.getBuiltinFnObject("valueSwap"));
 		else if( opcode.equalsIgnoreCase(Opcodes.FREPLICATE.toString()))
 			return new BinaryOperator(Builtin.getBuiltinFnObject("freplicate"));
+		else if( opcode.equalsIgnoreCase(Opcodes.SET_COLNAMES.toString()))
+			return new BinaryOperator(Builtin.getBuiltinFnObject("set_colnames"));
 		
 		throw new RuntimeException("Unknown binary opcode " + opcode);
 	}
@@ -923,6 +925,8 @@ else if ( opcode.equalsIgnoreCase(Opcodes.DROPINVALIDLENGTH.toString()) || opcod
 			return new BinaryOperator(Builtin.getBuiltinFnObject("dropInvalidLength"));
 		else if ( opcode.equalsIgnoreCase(Opcodes.VALUESWAP.toString()) || opcode.equalsIgnoreCase("mapValueSwap") )
 			return new BinaryOperator(Builtin.getBuiltinFnObject("valueSwap"));
+		else if ( opcode.equalsIgnoreCase(Opcodes.SET_COLNAMES.toString()) )
+			return new BinaryOperator(Builtin.getBuiltinFnObject("set_colnames"));
 		
 		throw new DMLRuntimeException("Unknown binary opcode " + opcode);
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
index e9771b2e7fe..6d4564b7752 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
@@ -62,6 +62,23 @@ else if(getOpcode().equals(Opcodes.APPLYSCHEMA.toString())) {
 			final int k = ((MultiThreadedOperator)_optr).getNumThreads();
 			final FrameBlock out = FrameLibApplySchema.applySchema(inBlock1, inBlock2, k);
 			ec.setFrameOutput(output.getName(), out);
+		}
+		else if(getOpcode().equals(Opcodes.SET_COLNAMES.toString())) {
+			// NOTE(review): the sibling branch uses inBlock1/inBlock2; if both inputs are already
+			// acquired at method scope, reuse them instead of acquiring them again - confirm
+			FrameBlock in = ec.getFrameInput(input1.getName());
+			FrameBlock names = ec.getFrameInput(input2.getName());
+			// the new column names are read from row 0 of the names frame
+			String[] colNames = new String[(int) names.getNumColumns()];
+			for(int i = 0; i < colNames.length; i++){
+				colNames[i] = names.get(0, i).toString();
+			}
+			// the names are set on a new FrameBlock constructed from the input
+			FrameBlock out = new FrameBlock(in);
+			out.setColumnNames(colNames);
+			ec.setFrameOutput(output.getName(), out);
+			ec.releaseFrameInput(input1.getName());
+			ec.releaseFrameInput(input2.getName());
 		}
 		else {
 			// Execute binary operations
diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java
new file mode 100644
index 00000000000..f72e7734c48
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.frame;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
+import org.apache.sysds.runtime.io.FrameWriter;
+import org.apache.sysds.runtime.io.FrameWriterFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/** Checks that frame column names are propagated through cbind, rbind and column slicing. */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FrameColNamesPropagationTest extends AutomatedTestBase {
+	private final static String TEST_NAME_CBIND = "ColNameCbindPropagation";
+	private final static String TEST_NAME_RBIND = "ColNameRbindPropagation";
+	private final static String TEST_NAME_SLICE = "ColNameSlicePropagation";
+	private final static String TEST_DIR = "functions/frame/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FrameColNamesPropagationTest.class.getSimpleName() + "/";
+
+	@Parameterized.Parameter
+	public int _matrixDim;
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		return Arrays.asList(new Object[][] {
+			{10},
+			{100},
+			{1000},
+		});
+	}
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME_CBIND, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_CBIND, new String[] {"B"}));
+		addTestConfiguration(TEST_NAME_RBIND, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_RBIND, new String[] {"B"}));
+		addTestConfiguration(TEST_NAME_SLICE, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_SLICE, new String[] {"B"}));
+
+	}
+
+	@Test
+	public void testPropagationCbindCP() {
+		runPropagationCbindTest(_matrixDim, ExecType.CP);
+	}
+
+	@Test
+	public void testPropagationRbindCP() {
+		runPropagationRbindTest(_matrixDim, ExecType.CP);
+	}
+
+	@Test
+	public void testPropagationSliceCP() {
+		runPropagationSliceTest(_matrixDim, ExecType.CP);
+	}
+
+	/** Builds n column names of the form prefix0, prefix1, ..., prefix(n-1). */
+	private String[] genColnames(int n, String prefix){
+		String[] colName = new String[n];
+		for(int i = 0; i < n; i++){
+			colName[i] = prefix + i;
+		}
+		return colName;
+	}
+
+	private void runPropagationCbindTest(Integer matrixDim, ExecType et) {
+		Types.ExecMode platformOld = setExecMode(et);
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		setOutputBuffering(true);
+		try {
+
+			// generate one array of column names per input frame (A0.. and B0..)
+			String[] colNames1 = genColnames(matrixDim, "A");
+			String[] colNames2 = genColnames(matrixDim, "B");
+
+			getAndLoadTestConfiguration(TEST_NAME_CBIND);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME_CBIND + ".dml";
+
+			// script arguments: X1 path, rows, cols of X1, X2 path, cols of X2, output path
+			programArgs = new String[] {"-args",
+					input("X1"), String.valueOf(matrixDim),
+					String.valueOf(matrixDim),
+					input("X2"),
+					Integer.toString(matrixDim),
+					output("B")};
+
+			FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+					new FileFormatPropertiesCSV(true, ",", false));
+
+			// X1: FP64 frame carrying the A* names, written as CSV
+			Types.ValueType[] schema1 = Collections.nCopies(
+					matrixDim, Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+			FrameBlock X1 = new FrameBlock(schema1);
+			X1.setColumnNames(colNames1);
+			double[][] data_X = getRandomMatrix(matrixDim, matrixDim, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+			TestUtils.initFrameData(X1, data_X, schema1, matrixDim);
+			writer.writeFrameToHDFS(X1, input("X1"), matrixDim, matrixDim);
+
+			// X2: FP64 frame carrying the B* names, written as CSV
+			Types.ValueType[] schema2 = Collections.nCopies(
+					matrixDim, Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+			FrameBlock X2 = new FrameBlock(schema2);
+			X2.setColumnNames(colNames2);
+			double[][] data_X2 = getRandomMatrix(matrixDim, matrixDim, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+			TestUtils.initFrameData(X2, data_X2, schema2, matrixDim);
+			writer.writeFrameToHDFS(X2, input("X2"), matrixDim, matrixDim);
+
+
+			runTest(true, false, null, -1);
+
+			// B holds the names returned by getNames() in its first row
+			FrameBlock out = readDMLFrameFromHDFS("B", FileFormat.BINARY);
+
+			// expected names: those of X1 followed by those of X2
+			String[] expected = new String[colNames1.length + colNames2.length];
+			System.arraycopy(colNames1, 0, expected, 0, colNames1.length);
+			System.arraycopy(colNames2, 0, expected, colNames1.length, colNames2.length);
+
+			// compare column names after operation with expected column names
+			for(int i = 0; i < expected.length; i++) {
+				Assert.assertEquals(
+						"Wrong colName at pos:" + i,
+						expected[i],
+						out.get(0, i).toString()
+				);
+			}
+
+		}
+		catch(Exception ex) {
+			throw new RuntimeException(ex);
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+	private void runPropagationRbindTest(Integer matrixDim, ExecType et) {
+		Types.ExecMode platformOld = setExecMode(et);
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		setOutputBuffering(true);
+		try {
+
+			// generate one array of column names per input frame (A0.. and B0..)
+			String[] colNames1 = genColnames(matrixDim, "A");
+			String[] colNames2 = genColnames(matrixDim, "B");
+
+			getAndLoadTestConfiguration(TEST_NAME_RBIND);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME_RBIND + ".dml";
+
+			// script arguments: X1 path, rows, cols of X1, X2 path, cols of X2, output path
+			programArgs = new String[] {"-args",
+					input("X1"), String.valueOf(matrixDim),
+					String.valueOf(matrixDim),
+					input("X2"),
+					Integer.toString(matrixDim),
+					output("B")};
+
+			FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+					new FileFormatPropertiesCSV(true, ",", false));
+
+			// X1: FP64 frame carrying the A* names, written as CSV
+			Types.ValueType[] schema1 = Collections.nCopies(
+					matrixDim, Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+			FrameBlock X1 = new FrameBlock(schema1);
+			X1.setColumnNames(colNames1);
+			double[][] data_X = getRandomMatrix(matrixDim, matrixDim, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+			TestUtils.initFrameData(X1, data_X, schema1, matrixDim);
+			writer.writeFrameToHDFS(X1, input("X1"), matrixDim, matrixDim);
+
+			// X2: FP64 frame carrying the B* names, written as CSV
+			Types.ValueType[] schema2 = Collections.nCopies(
+					matrixDim, Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+			FrameBlock X2 = new FrameBlock(schema2);
+			X2.setColumnNames(colNames2);
+			double[][] data_X2 = getRandomMatrix(matrixDim, matrixDim, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+			TestUtils.initFrameData(X2, data_X2, schema2, matrixDim);
+			writer.writeFrameToHDFS(X2, input("X2"), matrixDim, matrixDim);
+
+			runTest(true, false, null, -1);
+
+			FrameBlock out = readDMLFrameFromHDFS("B", FileFormat.BINARY);
+
+			// after rbind the names of the first input frame are expected
+			for(int i = 0; i < colNames1.length; i++) {
+				Assert.assertEquals(
+						"Wrong colName at pos:" + i,
+						colNames1[i],
+						out.get(0, i).toString()
+				);
+			}
+
+		}
+		catch(Exception ex) {
+			throw new RuntimeException(ex);
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+	private void runPropagationSliceTest(Integer matrixDim, ExecType et) {
+		Types.ExecMode platformOld = setExecMode(et);
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		setOutputBuffering(true);
+		try {
+
+			// generate the column names of the input frame (A0..)
+			String[] colNames = genColnames(matrixDim, "A");
+
+			getAndLoadTestConfiguration(TEST_NAME_SLICE);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME_SLICE + ".dml";
+
+			// script arguments: X path, rows, cols, output path
+			programArgs = new String[] {"-args",
+					input("X"), String.valueOf(matrixDim),
+					String.valueOf(matrixDim),
+					output("B")};
+
+			FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+					new FileFormatPropertiesCSV(true, ",", false));
+
+			// X: FP64 frame carrying the A* names, written as CSV
+			Types.ValueType[] schema = Collections.nCopies(
+					matrixDim, Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+			FrameBlock X1 = new FrameBlock(schema);
+			X1.setColumnNames(colNames);
+			double[][] data_X = getRandomMatrix(matrixDim, matrixDim, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+			TestUtils.initFrameData(X1, data_X, schema, matrixDim);
+			writer.writeFrameToHDFS(X1, input("X"), matrixDim, matrixDim);
+
+			runTest(true, false, null, -1);
+
+			FrameBlock out = readDMLFrameFromHDFS("B", FileFormat.BINARY);
+
+			String[] expected = Arrays.copyOfRange(colNames, 1, colNames.length-1);
+
+			// the script keeps columns 2..(n-1), i.e. 0-based indices 1..(n-2)
+			for(int i = 0; i < expected.length; i++) {
+				Assert.assertEquals(
+						"Wrong colName at pos:" + i,
+						expected[i],
+						out.get(0, i).toString()
+				);
+			}
+
+		}
+		catch(Exception ex) {
+			throw new RuntimeException(ex);
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+}
+
diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java
index d1ee4215e1a..a43302e6d1d 100644
--- a/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java
@@ -44,6 +44,8 @@
 @net.jcip.annotations.NotThreadSafe
 public class FrameColumnNamesTest extends AutomatedTestBase {
 	private final static String TEST_NAME = "ColumnNames";
+	private final static String TEST_NAME_GET = "GetNames";
+	private final static String TEST_NAME_SET = "SetNames";
 	private final static String TEST_DIR = "functions/frame/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + FrameColumnNamesTest.class.getSimpleName() + "/";
 
@@ -60,6 +62,9 @@ public static Collection data() {
 	@Override
 	public void setUp() {
 		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+		addTestConfiguration(TEST_NAME_GET, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_GET, new String[] {"B"}));
+		addTestConfiguration(TEST_NAME_SET, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_SET, new String[] {"B"}));
+
 	}
 
 	@Test
@@ -72,6 +77,107 @@ public void testDetectSchemaDoubleSpark() {
 		runGetColNamesTest(_columnNames, ExecType.SPARK);
 	}
 
+	@Test
+	public void testGetNamesCP() {
+		runGetNamesTest(_columnNames, ExecType.CP);
+	}
+
+	@Test
+	public void testSetNamesCP() {
+		runSetNamesTest(_columnNames, ExecType.CP);
+	}
+
+	private void runGetNamesTest(String[] columnNames, ExecType et) {
+		Types.ExecMode platformOld = setExecMode(et);
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		setOutputBuffering(true);
+		try {
+			getAndLoadTestConfiguration(TEST_NAME_GET);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME_GET + ".dml";
+			programArgs = new String[] {"-args", input("A"), String.valueOf(_rows),
+				Integer.toString(columnNames.length), output("B")};
+
+			Types.ValueType[] schema = Collections.nCopies(
+				columnNames.length, Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+			FrameBlock frame1 = new FrameBlock(schema);
+			frame1.setColumnNames(columnNames);
+			FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+				new FileFormatPropertiesCSV(true, ",", false));
+
+			double[][] A = getRandomMatrix(_rows, schema.length, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+			TestUtils.initFrameData(frame1, A, schema, _rows);
+			writer.writeFrameToHDFS(frame1, input("A"), _rows, schema.length);
+
+			runTest(true, false, null, -1);
+			FrameBlock frame2 = readDMLFrameFromHDFS("B", FileFormat.BINARY);
+
+			// verify that row 0 of the output holds the input column names
+			for(int i = 0; i < schema.length; i++) {
+				Assert
+					.assertEquals("Wrong result: " + columnNames[i] + ".", columnNames[i], frame2.get(0, i).toString());
+			}
+		}
+		catch(Exception ex) {
+			throw new RuntimeException(ex);
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+	
private void runSetNamesTest(String[] columnNames, ExecType et) {
+		Types.ExecMode platformOld = setExecMode(et);
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		setOutputBuffering(true);
+		try {
+			getAndLoadTestConfiguration(TEST_NAME_SET);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME_SET + ".dml";
+			programArgs = new String[] {"-args",input("X"),String.valueOf(_rows),Integer.toString(columnNames.length),
+				input("N"),output("B")
+			};
+
+			Types.ValueType[] schema = Collections.nCopies(
+				columnNames.length, Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+
+			FrameBlock frame1 = new FrameBlock(schema);
+			FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+				new FileFormatPropertiesCSV(true, ",", false));
+
+			double[][] A = getRandomMatrix(_rows, schema.length, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+			TestUtils.initFrameData(frame1, A, schema, _rows);
+			writer.writeFrameToHDFS(frame1, input("X"), _rows, schema.length);
+
+			Types.ValueType[] nameSchema = Collections.nCopies(
+				columnNames.length, Types.ValueType.STRING).toArray(new Types.ValueType[0]);
+
+			FrameBlock names = new FrameBlock(nameSchema);
+			names.ensureAllocatedColumns(1);
+			for(int i = 0; i < columnNames.length; i++)
+				names.set(0, i, columnNames[i]);
+			FrameWriter nameWriter = FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+				new FileFormatPropertiesCSV(false, ",", false));
+			// N is a single-row string frame; the script reads it with header=FALSE
+			nameWriter.writeFrameToHDFS(names, input("N"), 1, columnNames.length);
+
+			runTest(true, false, null, -1);
+
+			FrameBlock frame2 = readDMLFrameFromHDFS("B", FileFormat.BINARY);
+			for(int i = 0; i < columnNames.length; i++)
+				Assert.assertEquals("Wrong result: " + columnNames[i] + ".", columnNames[i], frame2.get(0, i).toString());
+		}
+		catch(Exception ex) {
+			throw new RuntimeException(ex);
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	
}
+
+
 	private void runGetColNamesTest(String[] columnNames, ExecType et) {
 		Types.ExecMode platformOld = setExecMode(et);
 		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
diff --git a/src/test/scripts/functions/frame/ColNameCbindPropagation.dml b/src/test/scripts/functions/frame/ColNameCbindPropagation.dml
new file mode 100644
index 00000000000..e46a4a7dbe8
--- /dev/null
+++ b/src/test/scripts/functions/frame/ColNameCbindPropagation.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Appends two frames column-wise and writes the column names of the result.
+X1 = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE);
+X2 = read($4, rows=$2, cols=$5, data_type="frame", format="csv", header=TRUE);
+Y = cbind(X1, X2);
+B = getNames(Y);
+write(B, $6, format="binary");
diff --git a/src/test/scripts/functions/frame/ColNameRbindPropagation.dml b/src/test/scripts/functions/frame/ColNameRbindPropagation.dml
new file mode 100644
index 00000000000..e11892f3645
--- /dev/null
+++ b/src/test/scripts/functions/frame/ColNameRbindPropagation.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Appends two frames row-wise and writes the column names of the result.
+X1 = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE);
+X2 = read($4, rows=$2, cols=$5, data_type="frame", format="csv", header=TRUE);
+Y = rbind(X1, X2);
+B = getNames(Y);
+write(B, $6, format="binary");
diff --git a/src/test/scripts/functions/frame/ColNameSlicePropagation.dml b/src/test/scripts/functions/frame/ColNameSlicePropagation.dml
new file mode 100644
index 00000000000..647e4f172d5
--- /dev/null
+++ b/src/test/scripts/functions/frame/ColNameSlicePropagation.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Keeps columns 2..(cols-1) of the frame and writes the column names of the slice.
+X = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE);
+Y = X[,2:($3-1)];
+B = getNames(Y);
+write(B, $4, format="binary");
diff --git a/src/test/scripts/functions/frame/GetNames.dml b/src/test/scripts/functions/frame/GetNames.dml
new file mode 100644
index 00000000000..70b8f22d8d9
--- /dev/null
+++ b/src/test/scripts/functions/frame/GetNames.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reads a frame with header and writes its column names obtained via getNames().
+X = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE);
+R = getNames(X);
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/frame/SetNames.dml b/src/test/scripts/functions/frame/SetNames.dml
new file mode 100644
index 00000000000..157a415babc
--- /dev/null
+++ b/src/test/scripts/functions/frame/SetNames.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Sets the names from the single-row frame N on X and writes them back out via getNames().
+X = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE);
+N = read($4, rows=1, cols=$3, data_type="frame", format="csv", header=FALSE);
+
+X2 = setNames(X, N);
+B = getNames(X2);
+
+write(B, $5, format="binary");