From ec0538bf43e4c34d0923b652871a0649e18342ed Mon Sep 17 00:00:00 2001 From: Biranavan Date: Sun, 21 Jun 2026 18:05:04 +0200 Subject: [PATCH] [SYSTEMDS-3946] Enable sending of large (>2GiB) FederatedRequests and Responses Federated transfers previously failed for payloads above 2GiB because the single Netty frame size is bounded by a 32-bit length field, capping any request or response at Integer.MAX_VALUE bytes. This patch adds a streaming chunked codec that splits a large payload into bounded frames on the sender and reassembles them on the receiver, so the on-wire size is no longer limited by a single frame. A format detector and format encoder select the chunked path only when the estimated serialized size of the message reaches the streaming threshold (1.5 GiB by default, below the single-frame limit), leaving the existing small-message path unchanged to avoid added overhead for the common case. Adds FederatedMaxPayloadTest (ignored by default because of its heap requirements) to exercise a transfer just past the former 2GiB cap. --- .../federated/FederatedChunkDecoder.java | 176 +++++++++++++++ .../federated/FederatedChunkEncoder.java | 208 ++++++++++++++++++ .../federated/FederatedChunkProtocol.java | 52 +++++ .../federated/FederatedData.java | 54 ++--- .../federated/FederatedFormatDetector.java | 45 ++++ .../federated/FederatedFormatEncoder.java | 77 +++++++ .../federated/FederatedWorker.java | 12 +- .../network/FederatedMaxPayloadTest.java | 81 +++++++ 8 files changed, 667 insertions(+), 38 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkDecoder.java create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkEncoder.java create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkProtocol.java create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedFormatDetector.java create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedFormatEncoder.java create mode 100644 
src/test/java/org/apache/sysds/test/functions/federated/network/FederatedMaxPayloadTest.java diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkDecoder.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkDecoder.java new file mode 100644 index 00000000000..51e0002e54b --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkDecoder.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.controlprogram.federated; + +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectStreamClass; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; + +import org.apache.sysds.runtime.util.CommonThreadPool; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageDecoder; + +public class FederatedChunkDecoder extends MessageToMessageDecoder { + private static final Object END_OF_STREAM = new Object(); + // stop reading at QUEUE_DEPTH, resume at half: gap avoids autoRead thrash + private static final int LOW_WATERMARK = FederatedChunkProtocol.QUEUE_DEPTH / 2; + + private final BlockingQueue _payloads = new LinkedBlockingQueue<>(); + private boolean _started; + private volatile boolean _throttled; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf frame, List out) { + startReader(ctx); + byte type = frame.readByte(); + int len = frame.readInt(); + switch(type) { + case FederatedChunkProtocol.TYPE_DATA: + _payloads.add(readBytes(frame, len)); + break; + case FederatedChunkProtocol.TYPE_END: + _payloads.add(END_OF_STREAM); + break; + case FederatedChunkProtocol.TYPE_ERROR: + _payloads.add(new IOException(frame.toString(frame.readerIndex(), len, StandardCharsets.UTF_8))); + break; + } + if(_payloads.size() >= FederatedChunkProtocol.QUEUE_DEPTH) { + _throttled = true; + ctx.channel().config().setAutoRead(false); + } + } + + private void startReader(ChannelHandlerContext ctx) { + if(_started) + return; + _started = true; + CommonThreadPool.getDynamicPool().execute(() -> runDeserializer(ctx)); + } + + private void runDeserializer(ChannelHandlerContext ctx) { + try(ObjectInputStream ois = objectInputStream(new PayloadInputStream(this, ctx))) { + Object msg = ois.readObject(); 
+ ctx.channel().eventLoop().execute(() -> ctx.fireChannelRead(msg)); + } + catch(Throwable t) { + ctx.channel().eventLoop().execute(() -> ctx.fireExceptionCaught(t)); + } + } + + private Object nextPayload() throws InterruptedException { + return _payloads.take(); + } + + private void resumeReadingIfDrained(ChannelHandlerContext ctx) { + if(_throttled && _payloads.size() <= LOW_WATERMARK) { + _throttled = false; + ctx.channel().eventLoop().execute(() -> ctx.channel().config().setAutoRead(true)); + } + } + + private static ObjectInputStream objectInputStream(InputStream in) throws IOException { + return new ObjectInputStream(in) { + @Override + protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + try { + return Class.forName(desc.getName(), false, ClassLoader.getSystemClassLoader()); + } + catch(ClassNotFoundException e) { + return super.resolveClass(desc); + } + } + }; + } + + private static byte[] readBytes(ByteBuf frame, int len) { + byte[] bytes = new byte[len]; + frame.readBytes(bytes); + return bytes; + } + + private static final class PayloadInputStream extends InputStream { + private static final byte[] EMPTY = new byte[0]; + + private final FederatedChunkDecoder _decoder; + private final ChannelHandlerContext _ctx; + private byte[] _current = EMPTY; + private int _pos; + private boolean _eof; + + PayloadInputStream(FederatedChunkDecoder decoder, ChannelHandlerContext ctx) { + _decoder = decoder; + _ctx = ctx; + } + + @Override + public int read() throws IOException { + if(!ensureCurrent()) + return -1; + return _current[_pos++] & 0xff; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if(!ensureCurrent()) + return -1; + int n = Math.min(len, _current.length - _pos); + System.arraycopy(_current, _pos, b, off, n); + _pos += n; + return n; + } + + private boolean ensureCurrent() throws IOException { + while(_pos == _current.length) { + if(_eof) + return false; + Object next = 
take(); + if(next == END_OF_STREAM) { + _eof = true; + return false; + } + if(next instanceof Throwable) + throw new IOException((Throwable) next); + _current = (byte[]) next; + _pos = 0; + } + return true; + } + + private Object take() throws IOException { + try { + Object next = _decoder.nextPayload(); + _decoder.resumeReadingIfDrained(_ctx); + return next; + } + catch(InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException(e); + } + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkEncoder.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkEncoder.java new file mode 100644 index 00000000000..e752c5614a8 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkEncoder.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.controlprogram.federated; + +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; + +import org.apache.sysds.runtime.util.CommonThreadPool; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.handler.stream.ChunkedInput; +import io.netty.handler.stream.ChunkedWriteHandler; + +public class FederatedChunkEncoder extends ChannelOutboundHandlerAdapter { + private final int _chunkSize; + + public FederatedChunkEncoder() { + this(FederatedChunkProtocol.DEFAULT_CHUNK_SIZE); + } + + public FederatedChunkEncoder(int chunkSize) { + _chunkSize = chunkSize; + } + + static ChunkedInput chunkedInput(Serializable msg, int chunkSize, ByteBufAllocator alloc, + ChunkedWriteHandler writer) { + return new SerializedChunks(msg, chunkSize, alloc, writer); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if(msg instanceof Serializable) + ctx.write(new SerializedChunks((Serializable) msg, _chunkSize, ctx.alloc(), + ctx.pipeline().get(ChunkedWriteHandler.class)), promise); + else + ctx.write(msg, promise); + } + + private static final class SerializedChunks implements ChunkedInput { + private final BlockingQueue _frames = new ArrayBlockingQueue<>(FederatedChunkProtocol.QUEUE_DEPTH); + private final ByteBufAllocator _alloc; + private final ChunkedWriteHandler _writer; + private volatile boolean _closed; + private boolean _done; + + SerializedChunks(Serializable msg, int chunkSize, ByteBufAllocator alloc, ChunkedWriteHandler writer) { + _alloc = alloc; + _writer = writer; + 
CommonThreadPool.getDynamicPool().execute(() -> produceFrames(msg, chunkSize)); + } + + private void produceFrames(Serializable msg, int chunkSize) { + try(FrameOutputStream out = new FrameOutputStream(this, _alloc, chunkSize); + ObjectOutputStream oos = new ObjectOutputStream(out)) { + oos.writeObject(msg); + oos.flush(); + out.flushFrame(); + enqueueControlFrame(controlFrame(FederatedChunkProtocol.TYPE_END)); + } + catch(Throwable t) { + enqueueControlFrame(errorFrame(t)); + } + } + + private ByteBuf controlFrame(byte type) { + return _alloc.buffer(FederatedChunkProtocol.HEADER_LEN).writeByte(type).writeInt(0); + } + + private ByteBuf errorFrame(Throwable t) { + byte[] cause = String.valueOf(t).getBytes(StandardCharsets.UTF_8); + return _alloc.buffer(FederatedChunkProtocol.HEADER_LEN + cause.length) + .writeByte(FederatedChunkProtocol.TYPE_ERROR).writeInt(cause.length).writeBytes(cause); + } + + void enqueueFrame(ByteBuf frame) throws InterruptedException { + if(_closed) { + frame.release(); + return; + } + _frames.put(frame); + _writer.resumeTransfer(); + } + + private void enqueueControlFrame(ByteBuf frame) { + try { + enqueueFrame(frame); + } + catch(InterruptedException e) { + frame.release(); + Thread.currentThread().interrupt(); + } + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) { + if(_done) + return null; + ByteBuf frame = _frames.poll(); + if(frame == null) + return null; + _done = frame.getByte(frame.readerIndex()) != FederatedChunkProtocol.TYPE_DATA; + return frame; + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) { + return readChunk(ctx.alloc()); + } + + @Override + public boolean isEndOfInput() { + return _done; + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return 0; + } + + @Override + public void close() { + _closed = true; + ByteBuf frame; + while((frame = _frames.poll()) != null) + frame.release(); + } + } + + private static final class 
FrameOutputStream extends OutputStream { + private final SerializedChunks _sink; + private final ByteBufAllocator _alloc; + private final byte[] _buffer; + private int _len; + + FrameOutputStream(SerializedChunks sink, ByteBufAllocator alloc, int chunkSize) { + _sink = sink; + _alloc = alloc; + _buffer = new byte[chunkSize]; + } + + @Override + public void write(int b) throws IOException { + _buffer[_len++] = (byte) b; + if(_len == _buffer.length) + flushFrame(); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + while(len > 0) { + int n = Math.min(len, _buffer.length - _len); + System.arraycopy(b, off, _buffer, _len, n); + _len += n; + off += n; + len -= n; + if(_len == _buffer.length) + flushFrame(); + } + } + + void flushFrame() throws IOException { + if(_len == 0) + return; + ByteBuf frame = _alloc.buffer(FederatedChunkProtocol.HEADER_LEN + _len) + .writeByte(FederatedChunkProtocol.TYPE_DATA).writeInt(_len).writeBytes(_buffer, 0, _len); + _len = 0; + try { + _sink.enqueueFrame(frame); + } + catch(InterruptedException e) { + frame.release(); + Thread.currentThread().interrupt(); + throw new IOException(e); + } + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkProtocol.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkProtocol.java new file mode 100644 index 00000000000..dc8e7662e46 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedChunkProtocol.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.federated; + +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +final class FederatedChunkProtocol { + static final byte TYPE_DATA = 0; + static final byte TYPE_END = 1; + static final byte TYPE_ERROR = 2; + + static final byte MARKER_LEGACY = 0; + static final byte MARKER_CHUNKED = 1; + static final long STREAM_THRESHOLD = 1536L << 20; // ~1.5 GB: route below this through the legacy object codec + + static final int HEADER_LEN = 5; + static final int DEFAULT_CHUNK_SIZE = 1 << 20; // 1 MB payload per frame + static final int QUEUE_DEPTH = 16; + + static final int LENGTH_FIELD_OFFSET = 1; + static final int LENGTH_FIELD_LENGTH = 4; + static final int LENGTH_ADJUSTMENT = 0; + static final int INITIAL_BYTES_TO_STRIP = 0; + + static int maxFrameLength(int chunkSize) { + return chunkSize + HEADER_LEN; + } + + static LengthFieldBasedFrameDecoder newFrameDecoder() { + return new LengthFieldBasedFrameDecoder(maxFrameLength(DEFAULT_CHUNK_SIZE), + LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH, LENGTH_ADJUSTMENT, INITIAL_BYTES_TO_STRIP); + } + + private FederatedChunkProtocol() {} +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java index 98572e2ddd0..4fc5a6c685b 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java @@ -22,7 +22,6 @@ 
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; -import java.io.Serializable; import java.net.ConnectException; import java.net.InetSocketAddress; import java.util.ArrayList; @@ -37,9 +36,11 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; +import io.netty.channel.ChannelFutureListener; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; @@ -60,11 +61,11 @@ import org.apache.sysds.runtime.controlprogram.paramserv.NetworkTrafficCounter; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.conf.DMLConfig; -import io.netty.buffer.ByteBuf; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.serialization.ObjectEncoder; +import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.util.concurrent.Promise; @@ -204,6 +205,7 @@ public synchronized static Future executeFederatedOperation(I createWorkGroup(); b.group(workerGroup); b.channel(NioSocketChannel.class); + b.option(ChannelOption.ALLOW_HALF_CLOSURE, true); final DataRequestHandler handler = new DataRequestHandler(); // Client Netty @@ -212,7 +214,12 @@ public synchronized static Future executeFederatedOperation(I ChannelFuture f = b.connect(address).sync(); Promise promise = f.channel().eventLoop().newPromise(); handler.setPromise(promise); - f.channel().writeAndFlush(request); + f.channel().writeAndFlush(request).addListener((ChannelFutureListener) future -> { + if (!future.isSuccess()) { + 
LOG.error("Federated network write failed: " + future.cause().getMessage()); + promise.setFailure(future.cause()); + } + }); return handler.getProm(); } @@ -255,9 +262,11 @@ protected void initChannel(SocketChannel ch) throws Exception { cp.addLast(new ReadTimeoutHandler(timeout)); compressionStrategy.ifPresent(strategy -> cp.addLast(strategy.left)); - cp.addLast(FederationUtils.decoder()); + cp.addLast(new FederatedFormatDetector()); compressionStrategy.ifPresent(strategy -> cp.addLast(strategy.right)); - cp.addLast(new FederatedRequestEncoder()); + cp.addLast(new ChunkedWriteHandler()); + cp.addLast(new ObjectEncoder()); + cp.addLast(new FederatedFormatEncoder()); cp.addLast(handler); } }; @@ -318,6 +327,15 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { ctx.close(); } + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + // Fail (rather than leave hanging) any request whose connection closed before its response + // was delivered, so a waiting caller gets an exception instead of blocking until timeout. 
+ if(_prom != null && !_prom.isDone()) + _prom.tryFailure(new IOException("Channel closed before federated response was received")); + super.channelInactive(ctx); + } + public Promise getProm() { return _prom; } @@ -334,32 +352,6 @@ public String toString() { return sb.toString(); } - public static class FederatedRequestEncoder extends ObjectEncoder { - @Override - protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, Serializable msg, boolean preferDirect) - throws Exception { - int initCapacity = 256; // default initial capacity - if(msg instanceof FederatedRequest[]) { - initCapacity = 0; - try { - for(FederatedRequest fr : (FederatedRequest[]) msg) { - int frSize = Math.toIntExact(fr.estimateSerializationBufferSize()); - if(Integer.MAX_VALUE - initCapacity < frSize) // summed sizes exceed integer limits - throw new ArithmeticException("Overflow."); - initCapacity += frSize; - } - } - catch(ArithmeticException ae) { // size of federated request exceeds integer limits - initCapacity = Integer.MAX_VALUE; - } - } - if(preferDirect) - return ctx.alloc().ioBuffer(initCapacity); - else - return ctx.alloc().heapBuffer(initCapacity); - } - } - /** * Requests privacy constraints from the federated worker * diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedFormatDetector.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedFormatDetector.java new file mode 100644 index 00000000000..4c0369a2bb8 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedFormatDetector.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.federated; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.ByteToMessageDecoder; + +public final class FederatedFormatDetector extends ByteToMessageDecoder { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + if(in.readableBytes() < 1) + return; + byte marker = in.readByte(); + ChannelPipeline cp = ctx.pipeline(); + if(marker == FederatedChunkProtocol.MARKER_CHUNKED) { + cp.addAfter(ctx.name(), "FederatedFrameDecoder", FederatedChunkProtocol.newFrameDecoder()); + cp.addAfter("FederatedFrameDecoder", "FederatedChunkDecoder", new FederatedChunkDecoder()); + } + else { + cp.addAfter(ctx.name(), "FederatedObjectDecoder", FederationUtils.decoder()); + } + cp.remove(this); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedFormatEncoder.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedFormatEncoder.java new file mode 100644 index 00000000000..41e78e3ddd1 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedFormatEncoder.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.federated; + +import java.io.Serializable; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.handler.stream.ChunkedWriteHandler; + +public class FederatedFormatEncoder extends ChannelOutboundHandlerAdapter { + private final int _chunkSize; + private final long _streamThreshold; + + public FederatedFormatEncoder() { + this(FederatedChunkProtocol.DEFAULT_CHUNK_SIZE, FederatedChunkProtocol.STREAM_THRESHOLD); + } + + public FederatedFormatEncoder(int chunkSize, long streamThreshold) { + _chunkSize = chunkSize; + _streamThreshold = streamThreshold; + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if(!(msg instanceof Serializable)) { + ctx.write(msg, promise); + return; + } + if(estimateSize(msg) >= _streamThreshold) { + ctx.write(markerBuffer(ctx, FederatedChunkProtocol.MARKER_CHUNKED), ctx.voidPromise()); + ctx.write(FederatedChunkEncoder.chunkedInput((Serializable) msg, _chunkSize, ctx.alloc(), + ctx.pipeline().get(ChunkedWriteHandler.class)), promise); + } + else { + ctx.write(markerBuffer(ctx, FederatedChunkProtocol.MARKER_LEGACY), ctx.voidPromise()); + ctx.write(msg, promise); + } + } + + private static ByteBuf 
markerBuffer(ChannelHandlerContext ctx, byte type) { + return ctx.alloc().buffer(1).writeByte(type); + } + + private static long estimateSize(Object msg) { + if(msg instanceof FederatedResponse) + return ((FederatedResponse) msg).estimateSerializationBufferSize(); + if(msg instanceof FederatedRequest) + return ((FederatedRequest) msg).estimateSerializationBufferSize(); + if(msg instanceof FederatedRequest[]) { + long size = 0; + for(FederatedRequest request : (FederatedRequest[]) msg) + size += request.estimateSerializationBufferSize(); + return size; + } + return 0; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java index fc8989053bc..38374eb567a 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java @@ -59,9 +59,8 @@ import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; -import io.netty.handler.codec.serialization.ClassResolvers; -import io.netty.handler.codec.serialization.ObjectDecoder; import io.netty.handler.codec.serialization.ObjectEncoder; +import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; @@ -219,14 +218,13 @@ public void initChannel(SocketChannel ch) { cp.addLast("CompressionDecodingStartStatistics", new CompressionDecoderStartStatisticsHandler()); compressionStrategy.ifPresent(strategy -> cp.addLast("CompressionDecoder", strategy.left)); cp.addLast("CompressionDecoderEndStatistics", new CompressionDecoderEndStatisticsHandler()); - cp.addLast("ObjectDecoder", - new ObjectDecoder(Integer.MAX_VALUE, - 
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader()))); + cp.addLast("FederatedFormatDetector", new FederatedFormatDetector()); cp.addLast("CompressionEncodingEndStatistics", new CompressionEncoderEndStatisticsHandler()); compressionStrategy.ifPresent(strategy -> cp.addLast("CompressionEncoder", strategy.right)); cp.addLast("CompressionEncodingStartStatistics", new CompressionEncoderStartStatisticsHandler()); - cp.addLast("ObjectEncoder", new ObjectEncoder()); - cp.addLast(FederationUtils.decoder(), new FederatedResponseEncoder()); + cp.addLast("ChunkedWriteHandler", new ChunkedWriteHandler()); + cp.addLast("FederatedResponseEncoder", new FederatedResponseEncoder()); + cp.addLast("FederatedFormatEncoder", new FederatedFormatEncoder()); cp.addLast(new FederatedWorkerHandler(_flt, _frc, _fan, networkTimer)); } }; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/network/FederatedMaxPayloadTest.java b/src/test/java/org/apache/sysds/test/functions/federated/network/FederatedMaxPayloadTest.java new file mode 100644 index 00000000000..0bcae9d51d4 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/network/FederatedMaxPayloadTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.network; + +import java.net.InetSocketAddress; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +@Ignore("heavy: client+worker share one JVM, so it holds two ~2GB matrix copies. Needs a ~9GB fork " + + "heap; -DargLine is ignored here (pom uses @{argLine}), so bump the pom 'argLine' property " + + "(e.g. -Xmx9g) to run manually. Verified green: 2.158GB streamed end-to-end past the 2GB Netty cap.") +public class FederatedMaxPayloadTest extends AutomatedTestBase { + + private final static String TEST_NAME = "FederatedMaxPayloadTest"; + private final static String TEST_DIR = "functions/federated/network/"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMaxPayloadTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {""})); + } + + @Test + public void transferOverTwoGigabytePayload() { + int port = getRandomAvailablePort(); + startLocalFedWorkerThread(port); + try { + MatrixBlock mb = denseMatrixExceedingTwoGigabytes(); + InetSocketAddress address = new InetSocketAddress("localhost", port); + FederatedRequest request = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, 1, mb); + + Future response = FederatedData.executeFederatedOperation(address, request); + 
Assert.assertTrue("Network send was not successful.", response.get().isSuccessful()); + } + catch(ExecutionException e) { + Assert.fail("Federated transfer failed: " + e.getMessage()); + } + catch(Exception e) { + Assert.fail("Federated transfer failed: " + e.getMessage()); + } + finally { + FederatedData.clearFederatedWorkers(); + } + } + + private static MatrixBlock denseMatrixExceedingTwoGigabytes() { + int rows = 30000; + int cols = 8950; + MatrixBlock mb = new MatrixBlock(rows, cols, false); + mb.allocateDenseBlock(); + mb.setNonZeros((long) rows * cols); + return mb; + } +}