001    /*
002     * Copyright (c) 2007-2014 Concurrent, Inc. All Rights Reserved.
003     *
004     * Project and contact information: http://www.cascading.org/
005     *
006     * This file is part of the Cascading project.
007     *
008     * Licensed under the Apache License, Version 2.0 (the "License");
009     * you may not use this file except in compliance with the License.
010     * You may obtain a copy of the License at
011     *
012     *     http://www.apache.org/licenses/LICENSE-2.0
013     *
014     * Unless required by applicable law or agreed to in writing, software
015     * distributed under the License is distributed on an "AS IS" BASIS,
016     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
017     * See the License for the specific language governing permissions and
018     * limitations under the License.
019     */
020    
021    package cascading.flow.hadoop.util;
022    
023    import java.io.ByteArrayInputStream;
024    import java.io.ByteArrayOutputStream;
025    import java.io.IOException;
026    import java.io.ObjectInputStream;
027    import java.io.ObjectOutputStream;
028    import java.io.ObjectStreamClass;
029    import java.io.Serializable;
030    import java.util.ArrayList;
031    import java.util.HashMap;
032    import java.util.List;
033    import java.util.Map;
034    import java.util.zip.GZIPInputStream;
035    import java.util.zip.GZIPOutputStream;
036    
037    import cascading.flow.FlowException;
038    
039    /** Class JavaObjectSerializer is the default implementation of {@link ObjectSerializer}. */
040    public class JavaObjectSerializer implements ObjectSerializer
041      {
042      @Override
043      public <T> byte[] serialize( T object, boolean compress ) throws IOException
044        {
045        if( object instanceof Map )
046          return serializeMap( (Map<String, ?>) object, compress );
047    
048        if( object instanceof List )
049          return serializeList( (List<?>) object, compress );
050    
051        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
052    
053        ObjectOutputStream out = new ObjectOutputStream( compress ? new GZIPOutputStream( bytes ) : bytes );
054    
055        try
056          {
057          out.writeObject( object );
058          }
059        finally
060          {
061          out.close();
062          }
063    
064        return bytes.toByteArray();
065        }
066    
067      @Override
068      public <T> T deserialize( byte[] bytes, Class<T> type, boolean decompress ) throws IOException
069        {
070    
071        if( Map.class.isAssignableFrom( type ) )
072          return (T) deserializeMap( bytes, decompress );
073    
074        if( List.class.isAssignableFrom( type ) )
075          {
076          return (T) deserializeList( bytes, decompress );
077          }
078    
079        ObjectInputStream in = null;
080    
081        try
082          {
083          ByteArrayInputStream byteStream = new ByteArrayInputStream( bytes );
084    
085          in = new ObjectInputStream( decompress ? new GZIPInputStream( byteStream ) : byteStream )
086          {
087          @Override
088          protected Class<?> resolveClass( ObjectStreamClass desc ) throws IOException, ClassNotFoundException
089            {
090            try
091              {
092              return Class.forName( desc.getName(), false, Thread.currentThread().getContextClassLoader() );
093              }
094            catch( ClassNotFoundException exception )
095              {
096              return super.resolveClass( desc );
097              }
098            }
099          };
100    
101          return (T) in.readObject();
102          }
103        catch( ClassNotFoundException exception )
104          {
105          throw new FlowException( "unable to deserialize data", exception );
106          }
107        finally
108          {
109          if( in != null )
110            in.close();
111          }
112        }
113    
114      @Override
115      public <T> boolean accepts( Class<T> type )
116        {
117        return Serializable.class.isAssignableFrom( type )
118          || Map.class.isAssignableFrom( type )
119          || List.class.isAssignableFrom( type );
120        }
121    
122      public <T> byte[] serializeMap( Map<String, T> map, boolean compress ) throws IOException
123        {
124        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
125        ObjectOutputStream out = new ObjectOutputStream( compress ? new GZIPOutputStream( bytes ) : bytes );
126    
127        Class<T> tClass;
128    
129        if( map.size() == 0 )
130          tClass = (Class<T>) Object.class;
131        else
132          tClass = (Class<T>) map.values().iterator().next().getClass();
133        try
134          {
135          out.writeInt( map.size() );
136          out.writeUTF( tClass.getName() );
137    
138          for( Map.Entry<String, T> entry : map.entrySet() )
139            {
140            out.writeUTF( entry.getKey() );
141            byte[] itemBytes = serialize( entry.getValue(), false );
142            out.writeInt( itemBytes.length );
143            out.write( itemBytes );
144            }
145          }
146        finally
147          {
148          out.close();
149          }
150    
151        return bytes.toByteArray();
152        }
153    
154      public <T> Map<String, T> deserializeMap( byte[] bytes, boolean decompress ) throws IOException
155        {
156        ObjectInputStream in = null;
157    
158        try
159          {
160          ByteArrayInputStream byteStream = new ByteArrayInputStream( bytes );
161    
162          in = new ObjectInputStream( decompress ? new GZIPInputStream( byteStream ) : byteStream );
163    
164          int mapSize = in.readInt();
165          Class<T> tClass = (Class<T>) Class.forName( in.readUTF() );
166    
167          Map<String, T> map = new HashMap<String, T>( mapSize );
168    
169          for( int j = 0; j < mapSize; j++ )
170            {
171            String key = in.readUTF();
172            byte[] valBytes = new byte[ in.readInt() ];
173            in.readFully( valBytes );
174            map.put( key, deserialize( valBytes, tClass, false ) );
175            }
176    
177          return map;
178          }
179        catch( ClassNotFoundException e )
180          {
181          throw new IOException( e );
182          }
183        finally
184          {
185          if( in != null )
186            in.close();
187          }
188        }
189    
190      public <T> byte[] serializeList( List<T> list, boolean compress ) throws IOException
191        {
192        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
193    
194        ObjectOutputStream out = new ObjectOutputStream( compress ? new GZIPOutputStream( bytes ) : bytes );
195    
196        Class<T> tClass;
197    
198        if( list.size() == 0 )
199          tClass = (Class<T>) Object.class;
200        else
201          tClass = (Class<T>) list.get( 0 ).getClass();
202    
203        try
204          {
205          out.writeInt( list.size() );
206          out.writeUTF( tClass.getName() );
207    
208          for( T item : list )
209            {
210            byte[] itemBytes = serialize( item, false );
211            out.writeInt( itemBytes.length );
212            out.write( itemBytes );
213            }
214          }
215        finally
216          {
217          out.close();
218          }
219    
220        return bytes.toByteArray();
221        }
222    
223      public <T> List<T> deserializeList( byte[] bytes, boolean decompress ) throws IOException
224        {
225        ObjectInputStream in = null;
226    
227        try
228          {
229          ByteArrayInputStream byteStream = new ByteArrayInputStream( bytes );
230    
231          in = new ObjectInputStream( decompress ? new GZIPInputStream( byteStream ) : byteStream );
232    
233          int listSize = in.readInt();
234          Class<T> tClass = (Class<T>) Class.forName( in.readUTF() );
235    
236          List<T> list = new ArrayList<T>( listSize );
237    
238          for( int i = 0; i < listSize; i++ )
239            {
240            byte[] itemBytes = new byte[ in.readInt() ];
241            in.readFully( itemBytes );
242            list.add( deserialize( itemBytes, tClass, false ) );
243            }
244    
245          return list;
246          }
247        catch( ClassNotFoundException e )
248          {
249          throw new IOException( e );
250          }
251        finally
252          {
253          if( in != null )
254            in.close();
255          }
256        }
257      }