diff --git a/docs/site/dml-language-reference.md b/docs/site/dml-language-reference.md
index 264b3c6a2b1..6a26c4caaac 100644
--- a/docs/site/dml-language-reference.md
+++ b/docs/site/dml-language-reference.md
@@ -2068,10 +2068,12 @@ The following example uses transformapply() with the input matrix a
**Table F5**: Frame processing built-in functions
-Function | Description | Parameters | Example
--------- | ----------- | ---------- | -------
-map() | It will execute the given lambda expression on a frame (cell, row or column wise). | Input: (X <frame>, y <String>, \[margin <int>\]) Output: <frame>. X is a frame and y is a String containing the lambda expression to be executed on frame X. margin - how to apply the lambda expression (0 indicates each cell, 1 - rows, 2 - columns). Output matrix dimensions are always equal to the input. | [map](#map)
-tokenize() | Transforms a frame to tokenized frame using specification. Tokenization is valid only for string columns. | Input: target = <frame> spec = <json specification> Outputs: <matrix>, <frame> | [tokenize](#tokenize)
+Function | Description | Parameters | Example
+-------- |-----------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------
+map() | It will execute the given lambda expression on a frame (cell, row or column wise). | Input: (X <frame>, y <String>, \[margin <int>\]) Output: <frame>. X is a frame and y is a String containing the lambda expression to be executed on frame X. margin - how to apply the lambda expression (0 indicates each cell, 1 - rows, 2 - columns). Output matrix dimensions are always equal to the input. | [map](#map)
+tokenize() | Transforms a frame to tokenized frame using specification. Tokenization is valid only for string columns. | Input: target = <frame> spec = <json specification> Outputs: <matrix>, <frame> | [tokenize](#tokenize)
+getNames() | Returns the column names of a frame as a single-row frame. | Input: X <frame> Output: <frame> | N = getNames(X)
+setNames() | Sets the column names of a frame from a single-row frame containing string values. | Input: X = <frame> N = <frame> Output: <frame> | Y = setNames(X, N)
#### map
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index f5719641df7..fa5da4cd91d 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -154,6 +154,7 @@ public enum Builtins {
GARCH("garch", true),
GAUSSIAN_CLASSIFIER("gaussianClassifier", true),
GET_ACCURACY("getAccuracy", true),
+ GET_NAMES("getNames", false),
GET_CATEGORICAL_MASK("getCategoricalMask", false),
GLM("glm", true),
GLM_PREDICT("glmPredict", true),
@@ -310,6 +311,7 @@ public enum Builtins {
SELVARTHRESH("selectByVarThresh", true),
SEQ("seq", false),
SES("ses", true),
+ SET_NAMES("setNames", false),
SYMMETRICDIFFERENCE("symmetricDifference", true),
SHAPEXPLAINER("shapExplainer", true),
SHERLOCK("sherlock", true),
diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java
index 9a894dde13b..0f971b764c8 100644
--- a/src/main/java/org/apache/sysds/common/Opcodes.java
+++ b/src/main/java/org/apache/sysds/common/Opcodes.java
@@ -352,6 +352,7 @@ public enum Opcodes {
MAPPM("map+*", InstructionType.Binary),
MAPMINUSMULT("map-*", InstructionType.Binary),
MAPDROPINVALIDLENGTH("mapdropInvalidLength", InstructionType.Binary),
+ SET_COLNAMES("set_colnames", InstructionType.Binary),
MAPGT("map>", InstructionType.Binary),
MAPGE("map>=", InstructionType.Binary),
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index c2832aeb8cd..e7282cde00e 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -641,7 +641,8 @@ public enum OpOp2 {
MINUS1_MULT(false), //1-X*Y
GET_CATEGORICAL_MASK(false), // get transformation mask
QUANTIZE_COMPRESS(false), //quantization-fused compression
- UNION_DISTINCT(false);
+ UNION_DISTINCT(false),
+ SET_COLNAMES(false);
private final boolean _validOuter;
diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index ab0c7993b4e..301b4d2765a 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1095,6 +1095,7 @@ else if( getAllExpr().length == 2 ) { //binary
case TYPEOF:
case DETECTSCHEMA:
case COLNAMES:
+ case GET_NAMES:
checkNumParameters(1);
checkMatrixFrameParam(getFirstExpr());
output.setDataType(DataType.FRAME);
@@ -1102,6 +1103,26 @@ else if( getAllExpr().length == 2 ) { //binary
output.setBlocksize (id.getBlocksize());
output.setValueType(ValueType.STRING);
break;
+ case SET_NAMES:
+ // setNames(X, N): X is the frame to rename, N supplies the new column names
+ checkNumParameters(2);
+
+ // first parameter: the frame whose column names are set
+ checkMatrixFrameParam(getFirstExpr());
+
+ // second parameter: the names, documented as a single-row (1 x n) frame
+ // NOTE(review): the 1 x n shape does not appear to be validated in this case
+ checkMatrixFrameParam(getSecondExpr());
+
+ // the output is a frame that takes its dimensions and blocksize from 'id'
+ // NOTE(review): 'id' is declared outside this hunk - presumably the
+ // identifier of the first argument; confirm
+ output.setDataType(DataType.FRAME);
+ output.setDimensions(id.getDim1(), id.getDim2());
+ output.setBlocksize (id.getBlocksize());
+ // NOTE(review): STRING mirrors the COLNAMES case above - confirm it fits a renamed data frame
+ output.setValueType(ValueType.STRING);
+ break;
case CAST_AS_FRAME:
// operation as.frame
// overloaded to take either one argument or 2 where second is column names
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index e14cfd31388..0000cc20677 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2762,6 +2762,22 @@ else if ( in.length == 2 )
case TYPEOF:
case DET:
case DETECTSCHEMA:
+ case GET_NAMES:
+ case SET_NAMES:
case COLNAMES:
+ // GET_NAMES/SET_NAMES share the labels of this group and branch off explicitly, so
+ // that TYPEOF, DET and DETECTSCHEMA still fall through to the generic UnaryOp below
+ if( source.getOpCode() == Builtins.SET_NAMES ) {
+ // setNames(X, N) lowers to the binary SET_COLNAMES operation
+ currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), OpOp2.SET_COLNAMES, expr, expr2);
+ break;
+ }
+ if( source.getOpCode() == Builtins.GET_NAMES ) {
+ // getNames(X) reuses the existing unary COLNAMES operation
+ currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), OpOp1.COLNAMES, expr);
+ break;
+ }
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
target.getValueType(), OpOp1.valueOf(source.getOpCode().name()), expr);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index da3de02419d..031cf406d8d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -686,6 +686,8 @@ else if( opcode.equalsIgnoreCase(Opcodes.VALUESWAP.toString()))
return new BinaryOperator(Builtin.getBuiltinFnObject("valueSwap"));
else if( opcode.equalsIgnoreCase(Opcodes.FREPLICATE.toString()))
return new BinaryOperator(Builtin.getBuiltinFnObject("freplicate"));
+ else if( opcode.equalsIgnoreCase(Opcodes.SET_COLNAMES.toString()))
+ return new BinaryOperator(Builtin.getBuiltinFnObject("set_colnames"));
throw new RuntimeException("Unknown binary opcode " + opcode);
}
@@ -923,6 +925,9 @@ else if ( opcode.equalsIgnoreCase(Opcodes.DROPINVALIDLENGTH.toString()) || opcod
return new BinaryOperator(Builtin.getBuiltinFnObject("dropInvalidLength"));
else if ( opcode.equalsIgnoreCase(Opcodes.VALUESWAP.toString()) || opcode.equalsIgnoreCase("mapValueSwap") )
return new BinaryOperator(Builtin.getBuiltinFnObject("valueSwap"));
+ // only the plain set_colnames opcode is matched; "mapValueSwap" is already consumed by the branch above
+ else if ( opcode.equalsIgnoreCase(Opcodes.SET_COLNAMES.toString()) )
+ return new BinaryOperator(Builtin.getBuiltinFnObject("set_colnames"));
throw new DMLRuntimeException("Unknown binary opcode " + opcode);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
index e9771b2e7fe..6d4564b7752 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
@@ -62,6 +62,28 @@ else if(getOpcode().equals(Opcodes.APPLYSCHEMA.toString())) {
final int k = ((MultiThreadedOperator)_optr).getNumThreads();
final FrameBlock out = FrameLibApplySchema.applySchema(inBlock1, inBlock2, k);
ec.setFrameOutput(output.getName(), out);
+ }
+ else if(getOpcode().equals(Opcodes.SET_COLNAMES.toString())) {
+ // setNames(X, N): the output is built from X and named by the first row of N
+ final FrameBlock in = ec.getFrameInput(input1.getName());
+ final FrameBlock names = ec.getFrameInput(input2.getName());
+ try {
+ // N needs a first row holding exactly one name per column of X
+ if(names.getNumRows() < 1 || names.getNumColumns() != in.getNumColumns())
+ throw new IllegalArgumentException("setNames: expected a 1x" + in.getNumColumns()
+ + " frame of column names, but got " + names.getNumRows() + "x" + names.getNumColumns());
+ final String[] colNames = new String[names.getNumColumns()];
+ for(int i = 0; i < colNames.length; i++)
+ colNames[i] = names.get(0, i).toString();
+ final FrameBlock out = new FrameBlock(in);
+ out.setColumnNames(colNames);
+ ec.setFrameOutput(output.getName(), out);
+ }
+ finally {
+ // release both inputs even if validation or name extraction fails
+ ec.releaseFrameInput(input1.getName());
+ ec.releaseFrameInput(input2.getName());
+ }
}
else {
// Execute binary operations
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java
index 107cab79d79..2dc56f513c5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java
@@ -52,6 +52,8 @@ else if(getOpcode().equals(Opcodes.COLNAMES.toString())) {
ec.releaseFrameInput(input1.getName());
ec.setFrameOutput(output.getName(), retBlock);
}
+ // getNames() needs no branch of its own: DMLTranslator lowers it to OpOp1.COLNAMES,
+ // so it is served by the COLNAMES branch directly above
else
throw new DMLScriptException("Opcode '" + getOpcode() + "' is not a valid UnaryFrameCPInstruction");
}
diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java
new file mode 100644
index 00000000000..f72e7734c48
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.frame;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
+import org.apache.sysds.runtime.io.FrameWriter;
+import org.apache.sysds.runtime.io.FrameWriterFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FrameColNamesPropagationTest extends AutomatedTestBase {
+ private final static String TEST_NAME_CBIND = "ColNameCbindPropagation";
+ private final static String TEST_NAME_RBIND = "ColNameRbindPropagation";
+ private final static String TEST_NAME_SLICE = "ColNameSlicePropagation";
+ private final static String TEST_DIR = "functions/frame/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FrameColNamesPropagationTest.class.getSimpleName() + "/";
+
+ @Parameterized.Parameter
+ public int _matrixDim;
+
+ @Parameterized.Parameters
+ public static Collection