001    /**
002     * Copyright (C) 2012 FuseSource, Inc.
003     * http://fusesource.com
004     *
005     * Licensed under the Apache License, Version 2.0 (the "License");
006     * you may not use this file except in compliance with the License.
007     * You may obtain a copy of the License at
008     *
009     *    http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.fusesource.hawtdispatch.transport;
019    
020    import org.fusesource.hawtbuf.Buffer;
021    import org.fusesource.hawtbuf.DataByteArrayOutputStream;
022    import org.fusesource.hawtdispatch.util.BufferPool;
023    import org.fusesource.hawtdispatch.util.BufferPools;
024    
025    import java.io.EOFException;
026    import java.io.IOException;
027    import java.net.ProtocolException;
028    import java.net.SocketException;
029    import java.nio.ByteBuffer;
030    import java.nio.channels.GatheringByteChannel;
031    import java.nio.channels.ReadableByteChannel;
032    import java.nio.channels.SocketChannel;
033    import java.nio.channels.WritableByteChannel;
034    import java.util.Arrays;
035    import java.util.LinkedList;
036    
037    /**
038     * Provides an abstract base class to make implementing the ProtocolCodec interface
039     * easier.
040     *
041     * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
042     */
043    public abstract class AbstractProtocolCodec implements ProtocolCodec, TransportAware {
044    
045        protected BufferPools bufferPools;
046        protected BufferPool writeBufferPool;
047        protected BufferPool readBufferPool;
048    
049        protected int writeBufferSize = 1024 * 64;
050        protected long writeCounter = 0L;
051        protected GatheringByteChannel writeChannel = null;
052        protected DataByteArrayOutputStream nextWriteBuffer;
053        protected long lastWriteIoSize = 0;
054    
055        protected LinkedList<ByteBuffer> writeBuffer = new LinkedList<ByteBuffer>();
056        private long writeBufferRemaining = 0;
057    
058    
059        public static interface Action {
060            Object apply() throws IOException;
061        }
062    
063        protected long readCounter = 0L;
064        protected int readBufferSize = 1024 * 64;
065        protected ReadableByteChannel readChannel = null;
066        protected ByteBuffer readBuffer;
067        protected ByteBuffer directReadBuffer = null;
068    
069        protected int readEnd;
070        protected int readStart;
071        protected int lastReadIoSize;
072        protected Action nextDecodeAction;
073    
074        public void setTransport(Transport transport) {
075            if( transport instanceof TcpTransport) {
076                TcpTransport tcp = (TcpTransport) transport;
077                writeBufferSize = tcp.getSendBufferSize();
078                readBufferSize = tcp.getReceiveBufferSize();
079            } else if( transport instanceof UdpTransport) {
080                UdpTransport tcp = (UdpTransport) transport;
081                writeBufferSize = tcp.getSendBufferSize();
082                readBufferSize = tcp.getReceiveBufferSize();
083            } else {
084                try {
085                    if (this.writeChannel instanceof SocketChannel) {
086                        writeBufferSize = ((SocketChannel) this.writeChannel).socket().getSendBufferSize();
087                        readBufferSize = ((SocketChannel) this.readChannel).socket().getReceiveBufferSize();
088                    } else if (this.writeChannel instanceof SslTransport.SSLChannel) {
089                        writeBufferSize = ((SslTransport.SSLChannel) this.readChannel).socket().getSendBufferSize();
090                        readBufferSize = ((SslTransport.SSLChannel) this.writeChannel).socket().getReceiveBufferSize();
091                    }
092                } catch (SocketException ignore) {
093                }
094            }
095            if( bufferPools!=null ) {
096                readBufferPool = bufferPools.getBufferPool(readBufferSize);
097                writeBufferPool = bufferPools.getBufferPool(writeBufferSize);
098            }
099        }
100    
101        public void setWritableByteChannel(WritableByteChannel channel) throws SocketException {
102            this.writeChannel = (GatheringByteChannel) channel;
103        }
104    
105        public int getReadBufferSize() {
106            return readBufferSize;
107        }
108    
109        public int getWriteBufferSize() {
110            return writeBufferSize;
111        }
112    
113        public boolean full() {
114            return writeBufferRemaining >= writeBufferSize;
115        }
116    
117        public boolean isEmpty() {
118            return writeBufferRemaining == 0 && (nextWriteBuffer==null || nextWriteBuffer.size() == 0);
119        }
120    
121        public long getWriteCounter() {
122            return writeCounter;
123        }
124    
125        public long getLastWriteSize() {
126            return lastWriteIoSize;
127        }
128    
129        abstract protected void encode(Object value) throws IOException;
130    
131        public ProtocolCodec.BufferState write(Object value) throws IOException {
132            if (full()) {
133                return ProtocolCodec.BufferState.FULL;
134            } else {
135                boolean wasEmpty = isEmpty();
136                if( nextWriteBuffer == null ) {
137                    nextWriteBuffer = allocateNextWriteBuffer();
138                }
139                encode(value);
140                if (nextWriteBuffer.size() >= (writeBufferSize* 0.75)) {
141                    flushNextWriteBuffer();
142                }
143                if (wasEmpty) {
144                    return ProtocolCodec.BufferState.WAS_EMPTY;
145                } else {
146                    return ProtocolCodec.BufferState.NOT_EMPTY;
147                }
148            }
149        }
150    
151        private DataByteArrayOutputStream allocateNextWriteBuffer() {
152            if( writeBufferPool !=null ) {
153                return new DataByteArrayOutputStream(writeBufferPool.checkout()) {
154                    @Override
155                    protected void resize(int newcount) {
156                        byte[] oldbuf = buf;
157                        super.resize(newcount);
158                        if( oldbuf.length == writeBufferPool.getBufferSize() ) {
159                            writeBufferPool.checkin(oldbuf);
160                        }
161                    }
162                };
163            } else {
164                return new DataByteArrayOutputStream(writeBufferSize);
165            }
166        }
167    
168        protected void writeDirect(ByteBuffer value) throws IOException {
169            // is the direct buffer small enough to just fit into the nextWriteBuffer?
170            int nextnextPospos = nextWriteBuffer.position();
171            int valuevalueLengthlength = value.remaining();
172            int available = nextWriteBuffer.getData().length - nextnextPospos;
173            if (available > valuevalueLengthlength) {
174                value.get(nextWriteBuffer.getData(), nextnextPospos, valuevalueLengthlength);
175                nextWriteBuffer.position(nextnextPospos + valuevalueLengthlength);
176            } else {
177                if (nextWriteBuffer!=null && nextWriteBuffer.size() != 0) {
178                    flushNextWriteBuffer();
179                }
180                writeBuffer.add(value);
181                writeBufferRemaining += value.remaining();
182            }
183        }
184    
185        protected void flushNextWriteBuffer() {
186            DataByteArrayOutputStream next = allocateNextWriteBuffer();
187            ByteBuffer bb = nextWriteBuffer.toBuffer().toByteBuffer();
188            writeBuffer.add(bb);
189            writeBufferRemaining += bb.remaining();
190            nextWriteBuffer = next;
191        }
192    
193        public ProtocolCodec.BufferState flush() throws IOException {
194            while (true) {
195                if (writeBufferRemaining != 0) {
196                    if( writeBuffer.size() == 1) {
197                        ByteBuffer b = writeBuffer.getFirst();
198                        lastWriteIoSize = writeChannel.write(b);
199                        if (lastWriteIoSize == 0) {
200                            return ProtocolCodec.BufferState.NOT_EMPTY;
201                        } else {
202                            writeBufferRemaining -= lastWriteIoSize;
203                            writeCounter += lastWriteIoSize;
204                            if(!b.hasRemaining()) {
205                                onBufferFlushed(writeBuffer.removeFirst());
206                            }
207                        }
208                    } else {
209                        ByteBuffer[] buffers = writeBuffer.toArray(new ByteBuffer[writeBuffer.size()]);
210                        lastWriteIoSize = writeChannel.write(buffers, 0, buffers.length);
211                        if (lastWriteIoSize == 0) {
212                            return ProtocolCodec.BufferState.NOT_EMPTY;
213                        } else {
214                            writeBufferRemaining -= lastWriteIoSize;
215                            writeCounter += lastWriteIoSize;
216                            while (!writeBuffer.isEmpty() && !writeBuffer.getFirst().hasRemaining()) {
217                                onBufferFlushed(writeBuffer.removeFirst());
218                            }
219                        }
220                    }
221                } else {
222                    if (nextWriteBuffer==null || nextWriteBuffer.size() == 0) {
223                        if( writeBufferPool!=null &&  nextWriteBuffer!=null ) {
224                            writeBufferPool.checkin(nextWriteBuffer.getData());
225                            nextWriteBuffer = null;
226                        }
227                        return ProtocolCodec.BufferState.EMPTY;
228                    } else {
229                        flushNextWriteBuffer();
230                    }
231                }
232            }
233        }
234    
235        /**
236         * Called when a buffer is flushed out.  Subclasses can implement
237         * in case they want to recycle the buffer.
238         *
239         * @param byteBuffer
240         */
241        protected void onBufferFlushed(ByteBuffer byteBuffer) {
242        }
243    
244        /////////////////////////////////////////////////////////////////////
245        //
246        // Non blocking read impl
247        //
248        /////////////////////////////////////////////////////////////////////
249    
250        abstract protected Action initialDecodeAction();
251    
252    
253        public void setReadableByteChannel(ReadableByteChannel channel) throws SocketException {
254            this.readChannel = channel;
255            if( nextDecodeAction==null ) {
256                nextDecodeAction = initialDecodeAction();
257            }
258        }
259    
260        public void unread(byte[] buffer) {
261            assert ((readCounter == 0));
262            readBuffer = ByteBuffer.allocate(buffer.length);
263            readBuffer.put(buffer);
264            readCounter += buffer.length;
265        }
266    
267        public long getReadCounter() {
268            return readCounter;
269        }
270    
271        public long getLastReadSize() {
272            return lastReadIoSize;
273        }
274    
275        public Object read() throws IOException {
276            Object command = null;
277            while (command == null) {
278                if (directReadBuffer != null) {
279                    while (directReadBuffer.hasRemaining()) {
280                        lastReadIoSize = readChannel.read(directReadBuffer);
281                        readCounter += lastReadIoSize;
282                        if (lastReadIoSize == -1) {
283                            throw new EOFException("Peer disconnected");
284                        } else if (lastReadIoSize == 0) {
285                            return null;
286                        }
287                    }
288                    command = nextDecodeAction.apply();
289                } else {
290                    if (readBuffer==null || readEnd == readBuffer.position()) {
291    
292                        if (readBuffer==null || readBuffer.remaining() == 0) {
293                            int size = readEnd - readStart;
294                            int newCapacity = 0;
295                            if (readStart == 0) {
296                                newCapacity = size + readBufferSize;
297                            } else {
298                                if (size > readBufferSize) {
299                                    newCapacity = size + readBufferSize;
300                                } else {
301                                    newCapacity = readBufferSize;
302                                }
303                            }
304    
305                            byte[] newBuffer;
306                            if (size > 0) {
307                                newBuffer = Arrays.copyOfRange(readBuffer.array(), readStart, readStart + newCapacity);
308                            } else {
309                                if( readBufferPool!=null) {
310                                    if (newCapacity == readBufferPool.getBufferSize()) {
311                                        newBuffer = readBufferPool.checkout();
312                                    } else {
313                                        newBuffer =  new byte[newCapacity];
314                                    }
315                                } else {
316                                    if (size > 0) {
317                                        newBuffer = Arrays.copyOfRange(readBuffer.array(), readStart, readStart + newCapacity);
318                                    } else {
319                                        newBuffer =  new byte[newCapacity];
320                                    }
321                                }
322                            }
323                            readBuffer = ByteBuffer.wrap(newBuffer);
324                            readBuffer.position(size);
325                            readStart = 0;
326                            readEnd = size;
327                        }
328                        int p = readBuffer.position();
329                        lastReadIoSize = readChannel.read(readBuffer);
330                        readCounter += lastReadIoSize;
331                        if (lastReadIoSize == -1) {
332                            readCounter += 1; // to compensate for that -1
333                            throw new EOFException("Peer disconnected");
334                        } else if (lastReadIoSize == 0) {
335                            if (readBufferPool != null && readStart == readEnd) {
336                                if (readEnd == 0 && readBuffer.array().length == readBufferPool.getBufferSize()) {
337                                    readBufferPool.checkin(readBuffer.array());
338                                } else {
339                                    readStart = 0;
340                                    readEnd = 0;
341                                }
342                                readBuffer = null;
343                            }
344                            return null;
345                        }
346                    }
347                    command = nextDecodeAction.apply();
348                    assert ((readStart <= readEnd));
349                    assert ((readEnd <= readBuffer.position()));
350                }
351            }
352            return command;
353        }
354    
355        protected Buffer readUntil(Byte octet) throws ProtocolException {
356            return readUntil(octet, -1);
357        }
358    
359        protected Buffer readUntil(Byte octet, int max) throws ProtocolException {
360            return readUntil(octet, max, "Maximum protocol buffer length exeeded");
361        }
362    
363        protected Buffer readUntil(Byte octet, int max, String msg) throws ProtocolException {
364            byte[] array = readBuffer.array();
365            Buffer buf = new Buffer(array, readEnd, readBuffer.position() - readEnd);
366            int pos = buf.indexOf(octet);
367            if (pos >= 0) {
368                int offset = readStart;
369                readEnd += pos + 1;
370                readStart = readEnd;
371                int length = readEnd - offset;
372                if (max >= 0 && length > max) {
373                    throw new ProtocolException(msg);
374                }
375                return new Buffer(array, offset, length);
376            } else {
377                readEnd += buf.length;
378                if (max >= 0 && (readEnd - readStart) > max) {
379                    throw new ProtocolException(msg);
380                }
381                return null;
382            }
383        }
384    
385        protected Buffer readBytes(int length) {
386            if ((readBuffer.position() - readStart) < length) {
387                readEnd = readBuffer.position();
388                return null;
389            } else {
390                int offset = readStart;
391                readEnd = offset + length;
392                readStart = readEnd;
393                return new Buffer(readBuffer.array(), offset, length);
394            }
395        }
396    
397        protected Buffer peekBytes(int length) {
398            if ((readBuffer.position() - readStart) < length) {
399                readEnd = readBuffer.position();
400                return null;
401            } else {
402                return new Buffer(readBuffer.array(), readStart, length);
403            }
404        }
405    
406        protected Boolean readDirect(ByteBuffer buffer) {
407            assert (directReadBuffer == null || (directReadBuffer == buffer));
408    
409            if (buffer.hasRemaining()) {
410                // First we need to transfer the read bytes from the non-direct
411                // byte buffer into the direct one..
412                int limit = readBuffer.position();
413                int transferSize = Math.min((limit - readStart), buffer.remaining());
414                byte[] readBufferArray = readBuffer.array();
415                buffer.put(readBufferArray, readStart, transferSize);
416    
417                // The direct byte buffer might have been smaller than our readBuffer one..
418                // compact the readBuffer to avoid doing additional mem allocations.
419                int trailingSize = limit - (readStart + transferSize);
420                if (trailingSize > 0) {
421                    System.arraycopy(readBufferArray, readStart + transferSize, readBufferArray, readStart, trailingSize);
422                }
423                readBuffer.position(readStart + trailingSize);
424            }
425    
426            // For big direct byte buffers, it will still not have been filled,
427            // so install it so that we directly read into it until it is filled.
428            if (buffer.hasRemaining()) {
429                directReadBuffer = buffer;
430                return false;
431            } else {
432                directReadBuffer = null;
433                buffer.flip();
434                return true;
435            }
436        }
437    
438        public BufferPools getBufferPools() {
439            return bufferPools;
440        }
441    
442        public void setBufferPools(BufferPools bufferPools) {
443            this.bufferPools = bufferPools;
444            if( bufferPools!=null ) {
445                readBufferPool = bufferPools.getBufferPool(readBufferSize);
446                writeBufferPool = bufferPools.getBufferPool(writeBufferSize);
447            } else {
448                readBufferPool = null;
449                writeBufferPool = null;
450            }
451        }
452    }