From 1cffaa23bb6f27ae471df33b49cb31ed617dba6e Mon Sep 17 00:00:00 2001 From: t99-i Date: Sun, 14 Jun 2026 17:36:14 +0200 Subject: [PATCH 1/5] WIP: add setName and getName first basics --- .../org/apache/sysds/common/Builtins.java | 2 ++ .../java/org/apache/sysds/common/Opcodes.java | 1 + .../java/org/apache/sysds/common/Types.java | 3 ++- .../parser/BuiltinFunctionExpression.java | 21 ++++++++++++++++++ .../apache/sysds/parser/DMLTranslator.java | 16 ++++++++++++++ .../instructions/InstructionUtils.java | 5 +++++ .../cp/BinaryFrameFrameCPInstruction.java | 22 +++++++++++++++++++ .../cp/UnaryFrameCPInstruction.java | 8 +++++++ 8 files changed, 77 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index e21c539d6d8..8c1e0690b0a 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -83,6 +83,8 @@ public enum Builtins { COLMEAN("colMeans", false), COLMIN("colMins", false), COLNAMES("colnames", false), + SET_NAMES("setNames", false), + GET_NAMES("getNames", false), COLPROD("colProds", false), COLSD("colSds", false), COLSUM("colSums", false), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 1b0536416d6..4acd5a949fb 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -350,6 +350,7 @@ public enum Opcodes { MAPPM("map+*", InstructionType.Binary), MAPMINUSMULT("map-*", InstructionType.Binary), MAPDROPINVALIDLENGTH("mapdropInvalidLength", InstructionType.Binary), + SET_COLNAMES("set_colnames", InstructionType.Binary), MAPGT("map>", InstructionType.Binary), MAPGE("map>=", InstructionType.Binary), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 2e3543882d2..c3ec1982467 100644 --- 
a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -640,7 +640,8 @@ public enum OpOp2 { LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5) MINUS1_MULT(false), //1-X*Y QUANTIZE_COMPRESS(false), //quantization-fused compression - UNION_DISTINCT(false); + UNION_DISTINCT(false), + SET_COLNAMES(false); private final boolean _validOuter; diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 28f6949f722..a213faf1d55 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -1095,7 +1095,28 @@ else if( getAllExpr().length == 2 ) { //binary case TYPEOF: case DETECTSCHEMA: case COLNAMES: + case GET_NAMES: checkNumParameters(1); + checkMatrixFrameParam(getFirstExpr()); + output.setDataType(DataType.FRAME); + output.setDimensions(1, id.getDim2()); + output.setBlocksize (id.getBlocksize()); + output.setValueType(ValueType.STRING); + break; + case SET_NAMES: + //check that we use 2 parameters (frame on which names are set and vector of names) + checkNumParameters(2); + + // check if first parameter is a frame + checkMatrixFrameParam(getFirstExpr()); + + // check if second parameter is a 1xn frame + checkMatrixFrameParam(getSecondExpr()); + + //output should be a frame + output.setDataType(DataType.FRAME); + + checkMatrixFrameParam(getFirstExpr()); output.setDataType(DataType.FRAME); output.setDimensions(1, id.getDim2()); diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index c6e7188d7bc..294fa45a037 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2762,6 +2762,22 @@ else if ( in.length == 2 ) case TYPEOF: case DET: case DETECTSCHEMA: + 
case GET_NAMES: /* the unary labels above fall through into this case, so it must stay a generic unary mapping */ + currBuiltinOp = new UnaryOp( + target.getName(), + target.getDataType(), + target.getValueType(), + "GET_NAMES".equals(source.getOpCode().name()) ? OpOp1.COLNAMES : OpOp1.valueOf(source.getOpCode().name()), expr + ); + break; + case SET_NAMES: + currBuiltinOp = new BinaryOp( + target.getName(), + target.getDataType(), + target.getValueType(), + OpOp2.SET_COLNAMES, expr, expr2 + ); + break; case COLNAMES: currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp1.valueOf(source.getOpCode().name()), expr); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index da3de02419d..031cf406d8d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -686,6 +686,8 @@ else if( opcode.equalsIgnoreCase(Opcodes.VALUESWAP.toString())) return new BinaryOperator(Builtin.getBuiltinFnObject("valueSwap")); else if( opcode.equalsIgnoreCase(Opcodes.FREPLICATE.toString())) return new BinaryOperator(Builtin.getBuiltinFnObject("freplicate")); + else if( opcode.equalsIgnoreCase(Opcodes.SET_COLNAMES.toString())) + return new BinaryOperator(Builtin.getBuiltinFnObject("set_colnames")); throw new RuntimeException("Unknown binary opcode " + opcode); } @@ -923,6 +925,9 @@ else if ( opcode.equalsIgnoreCase(Opcodes.DROPINVALIDLENGTH.toString()) || opcod return new BinaryOperator(Builtin.getBuiltinFnObject("dropInvalidLength")); else if ( opcode.equalsIgnoreCase(Opcodes.VALUESWAP.toString()) || opcode.equalsIgnoreCase("mapValueSwap") ) return new BinaryOperator(Builtin.getBuiltinFnObject("valueSwap")); + //set_colnames defines no map* alias in this patch, so only its own opcode is matched here + else if (opcode.equalsIgnoreCase(Opcodes.SET_COLNAMES.toString())) + return new BinaryOperator(Builtin.getBuiltinFnObject("set_colnames")); throw new DMLRuntimeException("Unknown binary opcode " + 
opcode); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java index e9771b2e7fe..6d4564b7752 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java @@ -62,6 +62,28 @@ else if(getOpcode().equals(Opcodes.APPLYSCHEMA.toString())) { final int k = ((MultiThreadedOperator)_optr).getNumThreads(); final FrameBlock out = FrameLibApplySchema.applySchema(inBlock1, inBlock2, k); ec.setFrameOutput(output.getName(), out); + } + else if(getOpcode().equals(Opcodes.SET_COLNAMES.toString())) { + + FrameBlock in = ec.getFrameInput(input1.getName()); + FrameBlock names = ec.getFrameInput(input2.getName()); + + String[] colNames = new String[(int) names.getNumColumns()]; + for(int i = 0; i < colNames.length; i++){ + colNames[i] = names.get(0, i).toString(); + } + + FrameBlock out = new FrameBlock(in); + + out.setColumnNames(colNames); + + ec.setFrameOutput(output.getName(), out); + + ec.releaseFrameInput(input1.getName()); + + ec.releaseFrameInput(input2.getName()); + + } else { // Execute binary operations diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java index 107cab79d79..2dc56f513c5 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java @@ -52,6 +52,14 @@ else if(getOpcode().equals(Opcodes.COLNAMES.toString())) { ec.releaseFrameInput(input1.getName()); ec.setFrameOutput(output.getName(), retBlock); } + //TODO: this branch repeats the COLNAMES condition of the branch above and looks unreachable; remove it unless getNames gets its own opcode + else if(getOpcode().equals(Opcodes.COLNAMES.toString())) { + FrameBlock 
inBlock = ec.getFrameInput(input1.getName()); + FrameBlock retBlock = inBlock.getColumnNamesAsFrame(); + ec.releaseFrameInput(input1.getName()); + ec.setFrameOutput(output.getName(), retBlock); + } + else throw new DMLScriptException("Opcode '" + getOpcode() + "' is not a valid UnaryFrameCPInstruction"); } From 94d7fadcd55e8644be7f7f362cff01b74529d8fd Mon Sep 17 00:00:00 2001 From: t99-i Date: Sat, 20 Jun 2026 14:05:58 +0200 Subject: [PATCH 2/5] WIP: - fix dim for SetNames - implemented tests for SetName and GetName --- .../parser/BuiltinFunctionExpression.java | 2 +- .../functions/frame/FrameColumnNamesTest.java | 106 ++++++++++++++++++ src/test/scripts/functions/frame/GetNames.dml | 24 ++++ src/test/scripts/functions/frame/SetNames.dml | 28 +++++ 4 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 src/test/scripts/functions/frame/GetNames.dml create mode 100644 src/test/scripts/functions/frame/SetNames.dml diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index a213faf1d55..875ec0a0ace 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -1119,7 +1119,7 @@ else if( getAllExpr().length == 2 ) { //binary checkMatrixFrameParam(getFirstExpr()); output.setDataType(DataType.FRAME); - output.setDimensions(1, id.getDim2()); + output.setDimensions(id.getDim1(), id.getDim2()); output.setBlocksize (id.getBlocksize()); output.setValueType(ValueType.STRING); break; diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java index d1ee4215e1a..a43302e6d1d 100644 --- a/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java +++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java @@ -44,6 +44,8 @@ 
@net.jcip.annotations.NotThreadSafe public class FrameColumnNamesTest extends AutomatedTestBase { private final static String TEST_NAME = "ColumnNames"; + private final static String TEST_NAME_GET = "GetNames"; + private final static String TEST_NAME_SET = "SetNames"; private final static String TEST_DIR = "functions/frame/"; private static final String TEST_CLASS_DIR = TEST_DIR + FrameColumnNamesTest.class.getSimpleName() + "/"; @@ -60,6 +62,9 @@ public static Collection data() { @Override public void setUp() { addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"})); + addTestConfiguration(TEST_NAME_GET, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_GET, new String[] {"B"})); + addTestConfiguration(TEST_NAME_SET, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_SET, new String[] {"B"})); + } @Test @@ -72,6 +77,107 @@ public void testDetectSchemaDoubleSpark() { runGetColNamesTest(_columnNames, ExecType.SPARK); } + @Test + public void testGetNamesCP() { + runGetNamesTest(_columnNames, ExecType.CP); + } + + @Test + public void testSetNamesCP() { + runSetNamesTest(_columnNames, ExecType.CP); + } + + private void runGetNamesTest(String[] columnNames, ExecType et) { + Types.ExecMode platformOld = setExecMode(et); + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + setOutputBuffering(true); + try { + getAndLoadTestConfiguration(TEST_NAME_GET); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME_GET + ".dml"; + programArgs = new String[] {"-args", input("A"), String.valueOf(_rows), + Integer.toString(columnNames.length), output("B")}; + + Types.ValueType[] schema = Collections.nCopies( + columnNames.length, Types.ValueType.FP64).toArray(new Types.ValueType[0]); + FrameBlock frame1 = new FrameBlock(schema); + frame1.setColumnNames(columnNames); + FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV, + new FileFormatPropertiesCSV(true, ",", false)); + + double[][] A = 
getRandomMatrix(_rows, schema.length, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123); + TestUtils.initFrameData(frame1, A, schema, _rows); + writer.writeFrameToHDFS(frame1, input("A"), _rows, schema.length); + + runTest(true, false, null, -1); + FrameBlock frame2 = readDMLFrameFromHDFS("B", FileFormat.BINARY); + + // verify output schema + for(int i = 0; i < schema.length; i++) { + Assert + .assertEquals("Wrong result: " + columnNames[i] + ".", columnNames[i], frame2.get(0, i).toString()); + } + } + catch(Exception ex) { + throw new RuntimeException(ex); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + + private void runSetNamesTest(String[] columnNames, ExecType et) { + Types.ExecMode platformOld = setExecMode(et); + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + setOutputBuffering(true); + try { + getAndLoadTestConfiguration(TEST_NAME_SET); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME_SET + ".dml"; + programArgs = new String[] {"-args",input("X"),String.valueOf(_rows),Integer.toString(columnNames.length), + input("N"),output("B") + }; + + Types.ValueType[] schema = Collections.nCopies( + columnNames.length, Types.ValueType.FP64).toArray(new Types.ValueType[0]); + + FrameBlock frame1 = new FrameBlock(schema); + FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV, + new FileFormatPropertiesCSV(true, ",", false)); + + double[][] A = getRandomMatrix(_rows, schema.length, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123); + TestUtils.initFrameData(frame1, A, schema, _rows); + writer.writeFrameToHDFS(frame1, input("X"), _rows, schema.length); + + Types.ValueType[] nameSchema = Collections.nCopies( + columnNames.length, Types.ValueType.STRING).toArray(new Types.ValueType[0]); + + FrameBlock names = new FrameBlock(nameSchema); + names.ensureAllocatedColumns(1); + for(int i = 0; i < columnNames.length; i++) + names.set(0, i, columnNames[i]); + 
FrameWriter nameWriter = FrameWriterFactory.createFrameWriter(FileFormat.CSV, + new FileFormatPropertiesCSV(false, ",", false)); + System.out.println("N path = " + input("N")); + nameWriter.writeFrameToHDFS(names, input("N"), 1, columnNames.length); + + runTest(true, false, null, -1); + + FrameBlock frame2 = readDMLFrameFromHDFS("B", FileFormat.BINARY); + for(int i = 0; i < columnNames.length; i++) + Assert.assertEquals("Wrong result: " + columnNames[i] + ".", columnNames[i], frame2.get(0, i).toString()); + } + catch(Exception ex) { + throw new RuntimeException(ex); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + + private void runGetColNamesTest(String[] columnNames, ExecType et) { Types.ExecMode platformOld = setExecMode(et); boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; diff --git a/src/test/scripts/functions/frame/GetNames.dml b/src/test/scripts/functions/frame/GetNames.dml new file mode 100644 index 00000000000..70b8f22d8d9 --- /dev/null +++ b/src/test/scripts/functions/frame/GetNames.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE); +R = getNames(X); +write(R, $4, format="binary"); \ No newline at end of file diff --git a/src/test/scripts/functions/frame/SetNames.dml b/src/test/scripts/functions/frame/SetNames.dml new file mode 100644 index 00000000000..157a415babc --- /dev/null +++ b/src/test/scripts/functions/frame/SetNames.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE); +N = read($4, rows=1, cols=$3, data_type="frame", format="csv", header=FALSE); + +X2 = setNames(X, N) +B = getNames(X2) + +write(B, $5, format="binary"); \ No newline at end of file From 2a2bb6f42d512e26b91995ee352ddff0b8938f04 Mon Sep 17 00:00:00 2001 From: t99-i Date: Sat, 20 Jun 2026 20:43:56 +0200 Subject: [PATCH 3/5] WIP: - add a test for propagation of column names during cbind operations - test for other operations following --- .../frame/FrameColNamesPropagationTest.java | 154 ++++++++++++++++++ .../functions/frame/ColNamePropagation.dml | 5 + 2 files changed, 159 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java create mode 100644 src/test/scripts/functions/frame/ColNamePropagation.dml diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java new file mode 100644 index 00000000000..42c3ef1d73a --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.frame; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; +import org.apache.sysds.runtime.io.FrameWriter; +import org.apache.sysds.runtime.io.FrameWriterFactory; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FrameColNamesPropagationTest extends AutomatedTestBase { + private final static String TEST_NAME = "ColNamePropagation"; + private final static String TEST_DIR = "functions/frame/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FrameColumnNamesTest.class.getSimpleName() + "/"; + + @Parameterized.Parameter + public int _matrixDim; + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { + {10}, + {100}, + {1000}, + }); + } + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"})); + } + + @Test + public void testPropagationCbindCP() { + runPropagationCbindTest(_matrixDim, ExecType.CP); + } + + private String[] genColnames(int n, String prefix){ + String[] colName = new String[n]; + for(int i = 0; i < n; i++){ + colName[i] = prefix + i; + } + return colName; + } + + private void runPropagationCbindTest(Integer matrixDim, ExecType et) 
{ + Types.ExecMode platformOld = setExecMode(et); + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + setOutputBuffering(true); + try { + + // generate an array of column names depending on the dimension of the frame block + String[] colNames1 = genColnames(matrixDim, "A"); + String[] colNames2 = genColnames(matrixDim, "B"); + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + + + programArgs = new String[] {"-args", + input("X1"), String.valueOf(matrixDim), + String.valueOf(matrixDim), + input("X2"), + Integer.toString(matrixDim), + output("B")}; + + FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV, + new FileFormatPropertiesCSV(true, ",", false)); + + + Types.ValueType[] schema1 = Collections.nCopies( + matrixDim, Types.ValueType.FP64).toArray(new Types.ValueType[0]); + FrameBlock X1 = new FrameBlock(schema1); + X1.setColumnNames(colNames1); + double[][] data_X = getRandomMatrix(matrixDim, matrixDim, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123); + TestUtils.initFrameData(X1, data_X, schema1, matrixDim); + writer.writeFrameToHDFS(X1, input("X1"), matrixDim, matrixDim); + + + Types.ValueType[] schema2 = Collections.nCopies( + matrixDim, Types.ValueType.FP64).toArray(new Types.ValueType[0]); + FrameBlock X2 = new FrameBlock(schema2); + X2.setColumnNames(colNames2); + double[][] data_X2 = getRandomMatrix(matrixDim, matrixDim, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123); + TestUtils.initFrameData(X2, data_X2, schema2, matrixDim); + writer.writeFrameToHDFS(X2, input("X2"), matrixDim, matrixDim); + + + runTest(true, false, null, -1); + + + FrameBlock out = readDMLFrameFromHDFS("B", FileFormat.BINARY); + + // create array of expected column names + String[] expected = new String[colNames1.length + colNames2.length]; + System.arraycopy(colNames1, 0, expected, 0, colNames1.length); + System.arraycopy(colNames2, 0, expected, colNames1.length, 
colNames2.length); + + // compare column names after operation with expected column names + for(int i = 0; i < expected.length; i++) { + Assert.assertEquals( + "Wrong colName at pos:" + i, + expected[i], + out.get(0, i).toString() + ); + } + + } + catch(Exception ex) { + throw new RuntimeException(ex); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + +} + diff --git a/src/test/scripts/functions/frame/ColNamePropagation.dml b/src/test/scripts/functions/frame/ColNamePropagation.dml new file mode 100644 index 00000000000..f042e0206ac --- /dev/null +++ b/src/test/scripts/functions/frame/ColNamePropagation.dml @@ -0,0 +1,5 @@ +X1 = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE); +X2 = read($4, rows=$2, cols=$5, data_type="frame", format="csv", header=TRUE); +Y = cbind(X1, X2); +B = getNames(Y); +write(B, $6, format="binary"); \ No newline at end of file From cdb8d678653efcb1cd31f84b0b1f3e6108fd6490 Mon Sep 17 00:00:00 2001 From: t99-i Date: Sun, 21 Jun 2026 10:10:11 +0200 Subject: [PATCH 4/5] Add column name propagation tests for cbind, rbind and slice --- .../frame/FrameColNamesPropagationTest.java | 146 +++++++++++++++++- .../frame/ColNameCbindPropagation.dml | 26 ++++ .../functions/frame/ColNamePropagation.dml | 5 - .../frame/ColNameRbindPropagation.dml | 26 ++++ .../frame/ColNameSlicePropagation.dml | 25 +++ 5 files changed, 219 insertions(+), 9 deletions(-) create mode 100644 src/test/scripts/functions/frame/ColNameCbindPropagation.dml delete mode 100644 src/test/scripts/functions/frame/ColNamePropagation.dml create mode 100644 src/test/scripts/functions/frame/ColNameRbindPropagation.dml create mode 100644 src/test/scripts/functions/frame/ColNameSlicePropagation.dml diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java index 42c3ef1d73a..f72e7734c48 
100644 --- a/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java @@ -43,7 +43,9 @@ @RunWith(value = Parameterized.class) @net.jcip.annotations.NotThreadSafe public class FrameColNamesPropagationTest extends AutomatedTestBase { - private final static String TEST_NAME = "ColNamePropagation"; + private final static String TEST_NAME_CBIND = "ColNameCbindPropagation"; + private final static String TEST_NAME_RBIND = "ColNameRbindPropagation"; + private final static String TEST_NAME_SLICE = "ColNameSlicePropagation"; private final static String TEST_DIR = "functions/frame/"; private static final String TEST_CLASS_DIR = TEST_DIR + FrameColumnNamesTest.class.getSimpleName() + "/"; @@ -61,7 +63,10 @@ public static Collection data() { @Override public void setUp() { - addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"})); + addTestConfiguration(TEST_NAME_CBIND, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_CBIND, new String[] {"B"})); + addTestConfiguration(TEST_NAME_RBIND, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_RBIND, new String[] {"B"})); + addTestConfiguration(TEST_NAME_SLICE, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_SLICE, new String[] {"B"})); + } @Test @@ -69,6 +74,17 @@ public void testPropagationCbindCP() { runPropagationCbindTest(_matrixDim, ExecType.CP); } + @Test + public void testPropagationRbindCP() { + runPropagationRbindTest(_matrixDim, ExecType.CP); + } + + @Test + public void testPropagationSliceCP() { + runPropagationSliceTest(_matrixDim, ExecType.CP); + } + + private String[] genColnames(int n, String prefix){ String[] colName = new String[n]; for(int i = 0; i < n; i++){ @@ -87,9 +103,9 @@ private void runPropagationCbindTest(Integer matrixDim, ExecType et) { String[] colNames1 = genColnames(matrixDim, "A"); String[] colNames2 = genColnames(matrixDim, "B"); - 
getAndLoadTestConfiguration(TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME_CBIND); String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + TEST_NAME + ".dml"; + fullDMLScriptName = HOME + TEST_NAME_CBIND + ".dml"; programArgs = new String[] {"-args", @@ -150,5 +166,127 @@ private void runPropagationCbindTest(Integer matrixDim, ExecType et) { } } + private void runPropagationRbindTest(Integer matrixDim, ExecType et) { + Types.ExecMode platformOld = setExecMode(et); + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + setOutputBuffering(true); + try { + + // generate an array of column names depending on the dimension of the frame block + String[] colNames1 = genColnames(matrixDim, "A"); + String[] colNames2 = genColnames(matrixDim, "B"); + + getAndLoadTestConfiguration(TEST_NAME_RBIND); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME_RBIND + ".dml"; + + + programArgs = new String[] {"-args", + input("X1"), String.valueOf(matrixDim), + String.valueOf(matrixDim), + input("X2"), + Integer.toString(matrixDim), + output("B")}; + + FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV, + new FileFormatPropertiesCSV(true, ",", false)); + + + Types.ValueType[] schema1 = Collections.nCopies( + matrixDim, Types.ValueType.FP64).toArray(new Types.ValueType[0]); + FrameBlock X1 = new FrameBlock(schema1); + X1.setColumnNames(colNames1); + double[][] data_X = getRandomMatrix(matrixDim, matrixDim, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123); + TestUtils.initFrameData(X1, data_X, schema1, matrixDim); + writer.writeFrameToHDFS(X1, input("X1"), matrixDim, matrixDim); + + + Types.ValueType[] schema2 = Collections.nCopies( + matrixDim, Types.ValueType.FP64).toArray(new Types.ValueType[0]); + FrameBlock X2 = new FrameBlock(schema2); + X2.setColumnNames(colNames2); + double[][] data_X2 = getRandomMatrix(matrixDim, matrixDim, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123); + TestUtils.initFrameData(X2, data_X2, 
schema2, matrixDim); + writer.writeFrameToHDFS(X2, input("X2"), matrixDim, matrixDim); + + runTest(true, false, null, -1); + + FrameBlock out = readDMLFrameFromHDFS("B", FileFormat.BINARY); + + // expected are the column names from the first frame block + for(int i = 0; i < colNames1.length; i++) { + Assert.assertEquals( + "Wrong colName at pos:" + i, + colNames1[i], + out.get(0, i).toString() + ); + } + + } + catch(Exception ex) { + throw new RuntimeException(ex); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + + private void runPropagationSliceTest(Integer matrixDim, ExecType et) { + Types.ExecMode platformOld = setExecMode(et); + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + setOutputBuffering(true); + try { + + // generate an array of column names depending on the dimension of the frame block + String[] colNames = genColnames(matrixDim, "A"); + + getAndLoadTestConfiguration(TEST_NAME_SLICE); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME_SLICE + ".dml"; + + + programArgs = new String[] {"-args", + input("X"), String.valueOf(matrixDim), + String.valueOf(matrixDim), + output("B")}; + + FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.CSV, + new FileFormatPropertiesCSV(true, ",", false)); + + + Types.ValueType[] schema = Collections.nCopies( + matrixDim, Types.ValueType.FP64).toArray(new Types.ValueType[0]); + FrameBlock X1 = new FrameBlock(schema); + X1.setColumnNames(colNames); + double[][] data_X = getRandomMatrix(matrixDim, matrixDim, Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123); + TestUtils.initFrameData(X1, data_X, schema, matrixDim); + writer.writeFrameToHDFS(X1, input("X"), matrixDim, matrixDim); + + runTest(true, false, null, -1); + + FrameBlock out = readDMLFrameFromHDFS("B", FileFormat.BINARY); + + String[] expected = Arrays.copyOfRange(colNames, 1, colNames.length-1); + + // expected are the sliced column names + for(int i = 
0; i < expected.length; i++) { + Assert.assertEquals( + "Wrong colName at pos:" + i, + expected[i], + out.get(0, i).toString() + ); + } + + } + catch(Exception ex) { + throw new RuntimeException(ex); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + } diff --git a/src/test/scripts/functions/frame/ColNameCbindPropagation.dml b/src/test/scripts/functions/frame/ColNameCbindPropagation.dml new file mode 100644 index 00000000000..e46a4a7dbe8 --- /dev/null +++ b/src/test/scripts/functions/frame/ColNameCbindPropagation.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X1 = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE); +X2 = read($4, rows=$2, cols=$5, data_type="frame", format="csv", header=TRUE); +Y = cbind(X1, X2); +B = getNames(Y); +write(B, $6, format="binary"); \ No newline at end of file diff --git a/src/test/scripts/functions/frame/ColNamePropagation.dml b/src/test/scripts/functions/frame/ColNamePropagation.dml deleted file mode 100644 index f042e0206ac..00000000000 --- a/src/test/scripts/functions/frame/ColNamePropagation.dml +++ /dev/null @@ -1,5 +0,0 @@ -X1 = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE); -X2 = read($4, rows=$2, cols=$5, data_type="frame", format="csv", header=TRUE); -Y = cbind(X1, X2); -B = getNames(Y); -write(B, $6, format="binary"); \ No newline at end of file diff --git a/src/test/scripts/functions/frame/ColNameRbindPropagation.dml b/src/test/scripts/functions/frame/ColNameRbindPropagation.dml new file mode 100644 index 00000000000..e11892f3645 --- /dev/null +++ b/src/test/scripts/functions/frame/ColNameRbindPropagation.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X1 = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE); +X2 = read($4, rows=$2, cols=$5, data_type="frame", format="csv", header=TRUE); +Y = rbind(X1, X2); +B = getNames(Y); +write(B, $6, format="binary"); \ No newline at end of file diff --git a/src/test/scripts/functions/frame/ColNameSlicePropagation.dml b/src/test/scripts/functions/frame/ColNameSlicePropagation.dml new file mode 100644 index 00000000000..647e4f172d5 --- /dev/null +++ b/src/test/scripts/functions/frame/ColNameSlicePropagation.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE); +Y = X[,2:($3-1)]; +B = getNames(Y); +write(B, $4, format="binary"); \ No newline at end of file From 2aa0e6f97f002aff735dae23a60913b13ead5faa Mon Sep 17 00:00:00 2001 From: t99-i Date: Sun, 21 Jun 2026 11:38:15 +0200 Subject: [PATCH 5/5] [SYSTEMDS-3857] Set/GetNames on Data Frames This patch adds the language reference entries for the newly implemented getNames and setNames functions. The order in Builtins.java was fixed to be alphabetical again --- docs/site/dml-language-reference.md | 10 ++++++---- src/main/java/org/apache/sysds/common/Builtins.java | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/site/dml-language-reference.md b/docs/site/dml-language-reference.md index 264b3c6a2b1..6a26c4caaac 100644 --- a/docs/site/dml-language-reference.md +++ b/docs/site/dml-language-reference.md @@ -2068,10 +2068,12 @@ The following example uses transformapply() with the input matrix a **Table F5**: Frame processing built-in functions -Function | Description | Parameters | Example --------- | ----------- | ---------- | ------- -map() | It will execute the given lambda expression on a frame (cell, row or column wise). | Input: (X <frame>, y <String>, \[margin <int>\])
Output: <frame>.
X is a frame and
y is a String containing the lambda expression to be executed on frame X.
margin - how to apply the lambda expression (0 indicates each cell, 1 - rows, 2 - columns). Output matrix dimensions are always equal to the input. | [map](#map) -tokenize() | Transforms a frame to tokenized frame using specification. Tokenization is valid only for string columns. | Input:
target = <frame>
spec = <json specification>
Outputs: <matrix>, <frame> | [tokenize](#tokenize) +Function | Description | Parameters | Example +-------- |-----------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- +map() | It will execute the given lambda expression on a frame (cell, row or column wise). | Input: (X <frame>, y <String>, \[margin <int>\])
Output: <frame>.
X is a frame and
y is a String containing the lambda expression to be executed on frame X.
margin - how to apply the lambda expression (0 indicates each cell, 1 - rows, 2 - columns). Output frame dimensions are always equal to the input. | [map](#map) +tokenize() | Transforms a frame to a tokenized frame using specification. Tokenization is valid only for string columns. | Input:
target = <frame>
spec = <json specification>
Outputs: <matrix>, <frame> | [tokenize](#tokenize) +getNames() | Returns the column names of a frame as a single-row frame. | Input: X <frame>
Output: <frame> | N = getNames(X) +setNames() | Sets the column names of a frame from a single-row frame containing string values. | Input:
X = <frame>
N = <frame>
Output: <frame> | Y = setNames(X, N) #### map diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 8c1e0690b0a..8811e912fd0 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -83,8 +83,6 @@ public enum Builtins { COLMEAN("colMeans", false), COLMIN("colMins", false), COLNAMES("colnames", false), - SET_NAMES("setNames", false), - GET_NAMES("getNames", false), COLPROD("colProds", false), COLSD("colSds", false), COLSUM("colSums", false), @@ -156,6 +154,7 @@ public enum Builtins { GARCH("garch", true), GAUSSIAN_CLASSIFIER("gaussianClassifier", true), GET_ACCURACY("getAccuracy", true), + GET_NAMES("getNames", false), GLM("glm", true), GLM_PREDICT("glmPredict", true), GLOVE("glove", true), @@ -311,6 +310,7 @@ public enum Builtins { SELVARTHRESH("selectByVarThresh", true), SEQ("seq", false), SES("ses", true), + SET_NAMES("setNames", false), SYMMETRICDIFFERENCE("symmetricDifference", true), SHAPEXPLAINER("shapExplainer", true), SHERLOCK("sherlock", true),