Uncompressing Data with the Deflate Algorithm in Java

The code shown below is an implementation of the Deflate algorithm written by Nayuki Minase, and slightly simplified by myself. As of this writing, the original version can be obtained at the URL http://nayuki.eigenstate.org/page/simple-deflate-implementation.

Specifically, I simplified it by squeezing all the classes into one file and removing a bunch of the comments. Heh. In all seriousness, Nayuki’s implementation seems to be the most readable and understandable one available out there. But I still have hopes of someday making it even clearer.

/*(MIT License)

Copyright © 2012 Nayuki Minase

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

* The above copyright notice and this permission notice shall be included in
  all copies or substantial portions of the Software.

* The Software is provided "as is", without warranty of any kind, express or
  implied, including but not limited to the warranties of merchantability,
  fitness for a particular purpose and noninfringement. In no event shall the
  authors or copyright holders be liable for any claim, damages or other
  liability, whether in an action of contract, tort or otherwise, arising from,
  out of or in connection with the Software or the use or other dealings in the
  Software.

*/

import java.io.*;
import java.util.*;
import java.math.*;

public class DeflateTest
{
	/**
	 * Decompresses a GZIP file: prints the header fields to standard output,
	 * inflates the DEFLATE payload and writes the result to another file.
	 * The 8-byte GZIP trailer (CRC-32 and size) is not read or verified.
	 *
	 * @param args exactly two elements: the path of the GZIP file to read and
	 *             the path of the file to write the decompressed bytes to
	 * @throws IOException if a file cannot be read or written, or the input ends prematurely
	 * @throws RuntimeException if the GZIP header is invalid or uses an unsupported feature
	 */
	public static void main(String[] args) throws IOException 
	{
		if (args.length != 2)
		{
			System.err.println("Usage: java DeflateTest InputFile.gz OutputFile");
			System.exit(1);
			return;
		}
		String fileNameToReadFrom = args[0];
		String fileNameToWriteTo = args[1];

		byte[] bytesDecompressed;
		DataInputStream in = new DataInputStream
		(
			new BufferedInputStream
			(
				new FileInputStream
				(
					new File(fileNameToReadFrom)
				), 
				16 * 1024
			)
		);
		try
		{
			readHeader(in);

			// Decompress
			bytesDecompressed = Decompressor.decompress
			(
				new ByteBitInputStream(in)
			);
		}
		finally
		{
			// Release the file handle even when the header or the data is malformed
			in.close();
		}

		// Write decompressed data to output file
		OutputStream outputStream = new FileOutputStream(fileNameToWriteTo);
		try
		{
			outputStream.write(bytesDecompressed);
		}
		finally
		{
			outputStream.close();
		}
	}

	// Reads the GZIP member header from 'in', printing each field that is present,
	// and leaves the stream positioned at the first byte of the compressed data.
	private static void readHeader(DataInputStream in) throws IOException
	{
		// Fixed 10-byte part of the header
		byte[] b = new byte[10];
		in.readFully(b);
		if (b[0] != 0x1F || b[1] != (byte)0x8B)
			throw new RuntimeException("Invalid GZIP magic number");
		if (b[2] != 8)
			throw new RuntimeException("Unsupported compression method: " + (b[2] & 0xFF));
		int flags = b[3] & 0xFF;

		// Reserved flags
		if ((flags & 0xE0) != 0)
			throw new RuntimeException("Reserved flags are set");

		// Modification time: little-endian seconds since the Unix epoch, 0 meaning "not available"
		int mtime = (b[4] & 0xFF) | (b[5] & 0xFF) << 8 | (b[6] & 0xFF) << 16 | (b[7] & 0xFF) << 24;
		if (mtime != 0)
		{
			// Treat the 32-bit field as unsigned before converting to milliseconds
			System.out.println("Last modified: " + new Date((mtime & 0xFFFFFFFFL) * 1000L));
		}
		else
		{
			System.out.println("Last modified: N/A");
		}

		// Extra flags
		switch (b[8] & 0xFF) 
		{
			case 2:   System.out.println("Extra flags: Maximum compression");  break;
			case 4:   System.out.println("Extra flags: Fastest compression");  break;
			default:  System.out.println("Extra flags: Unknown");  break;
		}

		// Operating system
		String os;
		switch (b[9] & 0xFF) 
		{
			case   0:  os = "FAT";             break;
			case   1:  os = "Amiga";           break;
			case   2:  os = "VMS";             break;
			case   3:  os = "Unix";            break;
			case   4:  os = "VM/CMS";          break;
			case   5:  os = "Atari TOS";       break;
			case   6:  os = "HPFS";            break;
			case   7:  os = "Macintosh";       break;
			case   8:  os = "Z-System";        break;
			case   9:  os = "CP/M";            break;
			case  10:  os = "TOPS-20";         break;
			case  11:  os = "NTFS";            break;
			case  12:  os = "QDOS";            break;
			case  13:  os = "Acorn RISCOS";    break;
			case 255:  os = "Unknown";         break;
			default :  os = "Really unknown";  break;
		}
		System.out.println("Operating system: " + os);

		// Text flag
		if ((flags & 0x01) != 0)
		{
			System.out.println("Flag: Text");
		}

		// The optional fields below appear in the file in exactly this order:
		// extra field, file name, comment, header CRC.

		// Extra flag
		if ((flags & 0x04) != 0) 
		{
			System.out.println("Flag: Extra");
			byte[] lenBytes = new byte[2];
			in.readFully(lenBytes);
			int len = (lenBytes[0] & 0xFF) | (lenBytes[1] & 0xFF) << 8;
			in.readFully(new byte[len]);  // Skip extra data
		}

		// File name flag
		if ((flags & 0x08) != 0) 
		{
			System.out.println("File name: " + readNullTerminatedString(in));
		}

		// Comment flag
		if ((flags & 0x10) != 0) 
		{
			System.out.println("Comment: " + readNullTerminatedString(in));
		}

		// Header CRC flag
		if ((flags & 0x02) != 0) 
		{
			byte[] crcBytes = new byte[2];
			in.readFully(crcBytes);
			System.out.printf("Header CRC-16: %04X%n", (crcBytes[0] & 0xFF) | (crcBytes[1] & 0xFF) << 8);
		}
	}

	// Reads bytes up to (and consuming) the terminating zero byte and returns them
	// as a string, mapping each byte value 0..255 to the char with the same code.
	// readUnsignedByte() throws EOFException itself if the input ends first.
	private static String readNullTerminatedString(DataInputStream in) throws IOException
	{
		StringBuilder sb = new StringBuilder();
		while (true) 
		{
			int temp = in.readUnsignedByte();
			if (temp == 0)  // Null-terminated string
				break;
			sb.append((char)temp);
		}
		return sb.toString();
	}
}

// A source of individual bits that also allows whole-byte reads.
interface BitInputStream
{
	// Returns the next bit (0 or 1), or -1 when no more bits are available.
	int read() throws IOException;

	// Returns the next bit (0 or 1); implementations signal end of input with an exception instead of -1.
	int readNoEof() throws IOException;

	// Returns the next whole byte from the underlying source.
	int readByte() throws IOException;

	// Returns the position of the next bit within the current byte.
	int getBitPosition();

	// Closes the underlying source.
	void close() throws IOException;
}

// Reads bits one at a time from a byte-oriented InputStream,
// handing out the least significant bit of each byte first.
final class ByteBitInputStream implements BitInputStream 
{
	// Underlying byte source; never null.
	private final InputStream input;

	// The byte most recently fetched by read(), or -1 once the source reported end of data.
	private int nextBits;

	// Index of the next bit to hand out from nextBits; 8 means a new byte must be fetched first.
	private int bitPosition = 8;

	// Becomes true when a bit read hits the end of the underlying stream.
	private boolean isEndOfStream = false;

	// Wraps the given stream. Throws NullPointerException if 'in' is null.
	public ByteBitInputStream(InputStream in) 
	{
		if (in == null)
			throw new NullPointerException("Argument is null");
		input = in;
	}

	// Returns the next bit (0 or 1), or -1 if the end of the stream has been reached.
	public int read() throws IOException 
	{
		if (!isEndOfStream && bitPosition == 8)
		{
			// Current byte is used up: fetch the next one
			int fetched = input.read();
			nextBits = fetched;
			if (fetched == -1)
				isEndOfStream = true;
			else
				bitPosition = 0;
		}

		if (isEndOfStream)
			return -1;

		int bit = (nextBits >>> bitPosition) & 1;
		bitPosition++;
		return bit;
	}

	// Returns the next bit (0 or 1), or throws EOFException if the end of the stream has been reached.
	public int readNoEof() throws IOException 
	{
		int bit = read();
		if (bit == -1)
			throw new EOFException("End of stream reached");
		return bit;
	}

	// Returns the index (0 to 7) of the next bit within the current byte; 0 means byte-aligned.
	public int getBitPosition() 
	{
		return (bitPosition == 8) ? 0 : bitPosition;
	}

	// Drops any bits left in the current byte and returns the next whole byte
	// straight from the underlying stream (-1 at end of stream).
	public int readByte() throws IOException 
	{
		bitPosition = 8;
		return input.read();
	}

	// Closes the underlying stream.
	public void close() throws IOException 
	{
		input.close();
	}
}

// A canonical Huffman code, described only by the code length of each symbol.
// A length of 0 means the symbol has no code.
final class CanonicalCode 
{
	// codeLengths[s] is the number of bits in the code for symbol s.
	private int[] codeLengths;

	// Builds a canonical code from an array of per-symbol code lengths (copied defensively).
	// Throws NullPointerException if the array is null and
	// IllegalArgumentException if any length is negative.
	public CanonicalCode(int[] codeLengths) 
	{
		if (codeLengths == null)
		{
			throw new NullPointerException("Argument is null");
		}

		this.codeLengths = codeLengths.clone();
		for (int x : this.codeLengths) 
		{
			if (x < 0)
			{
				throw new IllegalArgumentException("Illegal code length");
			}
		}
	}

	// Builds a canonical code by recording the depth of every leaf of the given tree.
	// Throws NullPointerException if the tree is null and IllegalArgumentException
	// if the tree contains a symbol that is not below symbolLimit.
	public CanonicalCode(CodeTree tree, int symbolLimit) 
	{
		if (tree == null)
		{
			throw new NullPointerException("Argument is null");
		}
		codeLengths = new int[symbolLimit];
		buildCodeLengths(tree.root, 0);
	}

	// Walks the subtree rooted at 'node', which sits at the given depth,
	// storing the depth of each leaf as the code length of its symbol.
	private void buildCodeLengths(Node node, int depth) 
	{
		if (node instanceof InternalNode) 
		{
			InternalNode internalNode = (InternalNode)node;
			buildCodeLengths(internalNode.leftChild , depth + 1);
			buildCodeLengths(internalNode.rightChild, depth + 1);
		} 
		else if (node instanceof Leaf) 
		{
			int symbol = ((Leaf)node).symbol;
			// The range check must come before the array is indexed
			if (symbol >= codeLengths.length)
			{
				throw new IllegalArgumentException("Symbol exceeds symbol limit");
			}
			if (codeLengths[symbol] != 0)
			{
				throw new AssertionError("Symbol has more than one code"); 
			}
			codeLengths[symbol] = depth;
		} 
		else 
		{
			throw new AssertionError("Illegal node type");
		}
	}

	// Returns the number of symbols this code covers (the length of the code length array).
	public int getSymbolLimit() 
	{
		return codeLengths.length;
	}

	// Returns the code length of the given symbol (0 if it has no code).
	// Throws IllegalArgumentException if the symbol is out of range.
	public int getCodeLength(int symbol) 
	{
		if (symbol < 0 || symbol >= codeLengths.length)
		{
			throw new IllegalArgumentException("Symbol out of range");
		}
		return codeLengths[symbol];
	}

	// Builds the code tree for this canonical code, working from the deepest
	// layer up to the root. Throws IllegalStateException if the code lengths
	// do not form a complete binary tree.
	public CodeTree toCodeTree() 
	{
		List<Node> nodes = new ArrayList<Node>();
		for (int i = max(codeLengths); i >= 1; i--) 
		{  
			// Descend through positive code lengths
			List<Node> newNodes = new ArrayList<Node>();

			// Add leaves for symbols with code length i
			for (int j = 0; j < codeLengths.length; j++) 
			{
				if (codeLengths[j] == i)
				{
					newNodes.add(new Leaf(j));
				}
			}

			// Merge nodes from the previous deeper layer; its size is even,
			// as verified at the end of the previous iteration
			for (int j = 0; j < nodes.size(); j += 2)
			{
				newNodes.add(new InternalNode(nodes.get(j), nodes.get(j + 1)));
			}

			nodes = newNodes;
			if (nodes.size() % 2 != 0)
			{
				throw new IllegalStateException("This canonical code does not represent a Huffman code tree");
			}
		}

		if (nodes.size() != 2)
		{
			throw new IllegalStateException("This canonical code does not represent a Huffman code tree");
		}
		return new CodeTree(new InternalNode(nodes.get(0), nodes.get(1)), codeLengths.length);
	}

	// Returns the largest value in the array, or 0 for an empty array.
	// Code lengths are never negative, so 0 is a safe starting point.
	private static int max(int[] array) 
	{
		int result = 0;
		for (int x : array)
		{
			result = Math.max(x, result);
		}
		return result;
	}	
}

// A fixed-size ring buffer holding the most recently produced bytes,
// used to resolve back-references during decompression.
final class CircularDictionary 
{
	// Backing storage for the ring buffer.
	private byte[] data;

	// Position where the next byte will be stored.
	private int index;

	// size - 1 when IntegerMath.isPowerOf2(size) reports true, otherwise 0.
	// A non-zero mask lets positions wrap with a bitwise AND instead of a remainder.
	private int mask;

	// Creates a dictionary holding 'size' bytes.
	public CircularDictionary(int size) 
	{
		data = new byte[size];
		index = 0;
		mask = IntegerMath.isPowerOf2(size) ? size - 1 : 0;
	}

	// Stores the low 8 bits of 'b' at the write position and advances it.
	public void append(int b) 
	{
		data[index] = (byte)b;
		index = next(index);
	}

	// Copies 'len' bytes starting 'dist' positions behind the write position,
	// writing each one to 'out' and appending it to the dictionary as it goes,
	// so a copy may overlap the bytes it is producing.
	// Throws IllegalArgumentException if len is negative or dist is not in 1..size.
	public void copy(int dist, int len, OutputStream out) throws IOException 
	{
		boolean valid = len >= 0 && dist >= 1 && dist <= data.length;
		if (!valid)
		{
			throw new IllegalArgumentException();
		}

		int from = wrap(index - dist + data.length);
		for (int remaining = len; remaining > 0; remaining--) 
		{
			byte value = data[from];
			out.write(value);
			data[index] = value;
			from = next(from);
			index = next(index);
		}
	}

	// Returns the position that follows 'position' in the ring.
	private int next(int position)
	{
		return wrap(position + 1);
	}

	// Reduces a non-negative position into the range of the buffer.
	private int wrap(int position)
	{
		if (mask != 0)
		{
			return position & mask;
		}
		return position % data.length;
	}
}

// A binary code tree together with a lookup table from symbol to bit sequence.
final class CodeTree 
{
	// Root of the tree; never null.
	public final InternalNode root;

	// codes.get(s) is the bit sequence for symbol s, or null if s has no code.
	// For example, a symbol with code 10011 is stored as the list [1, 0, 0, 1, 1].
	private List<List<Integer>> codes;

	// Wraps the tree under 'root'. Every symbol in the tree must be strictly
	// less than 'symbolLimit', otherwise IllegalArgumentException is thrown.
	// Throws NullPointerException if 'root' is null.
	public CodeTree(InternalNode root, int symbolLimit) 
	{
		if (root == null)
			throw new NullPointerException("Argument is null");
		this.root = root;

		// One slot per symbol, all initially without a code
		codes = new ArrayList<List<Integer>>();
		int slotsLeft = symbolLimit;
		while (slotsLeft > 0)
		{
			codes.add(null);
			slotsLeft--;
		}

		buildCodeList(root, new ArrayList<Integer>());
	}

	// Depth-first walk that records, for every leaf, the path of bits leading to it.
	// 'prefix' holds the path to 'node' and is restored before returning.
	private void buildCodeList(Node node, List<Integer> prefix) 
	{
		if (node instanceof Leaf)
		{
			int symbol = ((Leaf)node).symbol;
			if (symbol >= codes.size())
				throw new IllegalArgumentException("Symbol exceeds symbol limit");
			if (codes.get(symbol) != null)
				throw new IllegalArgumentException("Symbol has more than one code");
			codes.set(symbol, new ArrayList<Integer>(prefix));
			return;
		}

		if (!(node instanceof InternalNode))
			throw new AssertionError("Illegal node type");

		// Bit 0 selects the left child, bit 1 the right child
		InternalNode branch = (InternalNode)node;
		Node[] children = { branch.leftChild, branch.rightChild };
		for (int bit = 0; bit < children.length; bit++)
		{
			prefix.add(bit);
			buildCodeList(children[bit], prefix);
			prefix.remove(prefix.size() - 1);
		}
	}

	// Returns the bit sequence for the given symbol.
	// Throws IllegalArgumentException if the symbol is negative or has no code.
	public List<Integer> getCode(int symbol) 
	{
		if (symbol < 0)
			throw new IllegalArgumentException("Illegal symbol");

		List<Integer> code = codes.get(symbol);
		if (code == null)
			throw new IllegalArgumentException("No code for given symbol");
		return code;
	}

	// Returns one line per leaf listing its code and symbol. Intended for debugging; the format may change.
	@Override
	public String toString() 
	{
		StringBuilder text = new StringBuilder();
		describe(root, "", text);
		return text.toString();
	}

	// Appends a description of every leaf below 'node' to 'text'; 'path' is the code leading to 'node'.
	private static void describe(Node node, String path, StringBuilder text) 
	{
		if (node instanceof Leaf)
		{
			text.append(String.format("Code %s: Symbol %d%n", path, ((Leaf)node).symbol));
			return;
		}

		if (!(node instanceof InternalNode))
			throw new AssertionError("Illegal node type");

		InternalNode branch = (InternalNode)node;
		describe(branch.leftChild , path + "0", text);
		describe(branch.rightChild, path + "1", text);
	}
}

// Common base type for the nodes of a code tree.
// It carries no state of its own.
abstract class Node
{
	public Node()
	{
	}
}

// A branching node of a code tree with exactly two children.
final class InternalNode extends Node 
{
	// Child followed for bit 0; never null.
	public final Node leftChild;

	// Child followed for bit 1; never null.
	public final Node rightChild;

	// Creates a node with the given children.
	// Throws NullPointerException if either child is null.
	public InternalNode(Node leftChild, Node rightChild) 
	{
		this.leftChild = requirePresent(leftChild);
		this.rightChild = requirePresent(rightChild);
	}

	// Returns 'child' unchanged, rejecting null.
	private static Node requirePresent(Node child)
	{
		if (child == null)
			throw new NullPointerException("Argument is null");
		return child;
	}
}

// A terminal node of a code tree, carrying one symbol value.
final class Leaf extends Node 
{
	// The symbol stored in this leaf; never negative.
	public final int symbol;

	// Creates a leaf for the given symbol.
	// Throws IllegalArgumentException if the symbol is negative.
	public Leaf(int symbol) 
	{
		boolean negative = symbol < 0;
		if (negative)
		{
			throw new IllegalArgumentException("Illegal symbol value");
		}
		this.symbol = symbol;
	}
}

// Decompresses a raw DEFLATE stream read bit by bit from a BitInputStream.
final class Decompressor 
{	
	/* Public method */

	/**
	 * Decompresses every block of the DEFLATE stream, up to and including the
	 * block marked as final, and returns all of the decompressed bytes.
	 *
	 * @param in the bit stream positioned at the first block header
	 * @return the decompressed data
	 * @throws IOException if reading fails; EOFException if the stream ends prematurely
	 * @throws FormatException if the stream violates the DEFLATE format
	 * @throws NullPointerException if 'in' is null
	 */
	public static byte[] decompress(BitInputStream in) throws IOException 
	{
		Decompressor decompressor = new Decompressor(in);
		return decompressor.output.toByteArray();
	}

	/* Private members */
	private BitInputStream input;
	private ByteArrayOutputStream output;    // Everything decompressed so far
	private CircularDictionary dictionary;   // Recent output, for resolving back-references

	// Performs the whole decompression; the result is left in 'output'.
	private Decompressor(BitInputStream in) throws IOException 
	{
		if (in == null)
		{
			throw new NullPointerException("Argument is null");
		}
		input = in;
		output = new ByteArrayOutputStream();
		dictionary = new CircularDictionary(32 * 1024);

		// Process the stream of blocks
		while (true) 
		{
			// Block header
			boolean isFinal = in.readNoEof() == 1;  // bfinal
			int type = readInt(2);                  // btype

			// Decompress by type
			if (type == 0)
			{
				decompressUncompressedBlock();
			}
			else if (type == 1 || type == 2) 
			{
				CodeTree litLenCode, distCode;
				if (type == 1) 
				{
					litLenCode = fixedLiteralLengthCode;
					distCode = fixedDistanceCode;
				} 
				else 
				{
					CodeTree[] temp = decodeHuffmanCodes(in);
					litLenCode = temp[0];
					distCode = temp[1];
				}

				decompressHuffmanBlock(litLenCode, distCode);	
			} 
			else if (type == 3)
			{
				throw new FormatException("Invalid block type");
			}
			else
			{
				throw new AssertionError();
			}

			if (isFinal)
			{
				break;
			}
		}
	}

	// For handling static Huffman codes (btype = 1)
	private static final CodeTree fixedLiteralLengthCode;
	private static final CodeTree fixedDistanceCode;

	static 
	{
		int[] llcodelens = new int[288];
		Arrays.fill(llcodelens,   0, 144, 8);
		Arrays.fill(llcodelens, 144, 256, 9);
		Arrays.fill(llcodelens, 256, 280, 7);
		Arrays.fill(llcodelens, 280, 288, 8);
		fixedLiteralLengthCode = new CanonicalCode(llcodelens).toCodeTree();

		int[] distcodelens = new int[32];
		Arrays.fill(distcodelens, 5);
		fixedDistanceCode = new CanonicalCode(distcodelens).toCodeTree();
	}

	// For handling dynamic Huffman codes (btype = 2).
	// Reads the code definitions at the start of a dynamic block and returns
	// {literal/length code, distance code}. The distance code is null when no
	// distance symbol has a code, in which case the block may only contain literals.
	private CodeTree[] decodeHuffmanCodes(BitInputStream in) throws IOException 
	{
		int numLitLenCodes = readInt(5) + 257;  // hlit  + 257
		int numDistCodes = readInt(5) + 1;      // hdist +   1

		// The code length code lengths are stored in the order
		// 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
		int numCodeLenCodes = readInt(4) + 4;   // hclen +   4
		int[] codeLenCodeLen = new int[19];
		codeLenCodeLen[16] = readInt(3);
		codeLenCodeLen[17] = readInt(3);
		codeLenCodeLen[18] = readInt(3);
		codeLenCodeLen[ 0] = readInt(3);
		for (int i = 0; i < numCodeLenCodes - 4; i++) 
		{
			if (i % 2 == 0)
				codeLenCodeLen[8 + i / 2] = readInt(3);
			else
				codeLenCodeLen[7 - i / 2] = readInt(3);
		}
		CodeTree codeLenCode = new CanonicalCode(codeLenCodeLen).toCodeTree();

		// Decode the run-length-encoded code lengths of both alphabets in one pass
		int[] codeLens = new int[numLitLenCodes + numDistCodes];
		int runVal = -1;
		int runLen = 0;
		for (int i = 0; i < codeLens.length; i++) 
		{
			if (runLen > 0) 
			{
				codeLens[i] = runVal;
				runLen--;	
			} 
			else 
			{
				int sym = decodeSymbol(codeLenCode);
				if (sym < 16) 
				{
					codeLens[i] = sym;
					runVal = sym;
				} 
				else 
				{
					if (sym == 16) 
					{
						if (runVal == -1)
						{
							throw new FormatException("No code length value to copy");
						}
						runLen = readInt(2) + 3;
					} 
					else if (sym == 17) 
					{
						runVal = 0;
						runLen = readInt(3) + 3;
					} 
					else if (sym == 18) 
					{
						runVal = 0;
						runLen = readInt(7) + 11;
					} 
					else
					{
						throw new AssertionError();
					}

					// A run symbol fills no slot itself; revisit this index to start the run
					i--;
				}
			}
		}
		if (runLen > 0)
		{
			throw new FormatException("Run exceeds number of codes");
		}

		// Create code trees
		int[] litLenCodeLen = Arrays.copyOf(codeLens, numLitLenCodes);
		CodeTree litLenCode = new CanonicalCode(litLenCodeLen).toCodeTree();

		int[] distCodeLen = Arrays.copyOfRange(codeLens, numLitLenCodes, codeLens.length);

		// Classify the distance code lengths to detect the two special cases below
		int oneCount = 0;            // Number of symbols with code length 1
		int otherPositiveCount = 0;  // Number of symbols with code length 2 or more
		int onlyOneSymbol = -1;      // Last symbol seen with code length 1
		for (int sym = 0; sym < distCodeLen.length; sym++)
		{
			if (distCodeLen[sym] == 1)
			{
				oneCount++;
				onlyOneSymbol = sym;
			}
			else if (distCodeLen[sym] > 1)
			{
				otherPositiveCount++;
			}
		}

		CodeTree distCode;
		if (oneCount == 0 && otherPositiveCount == 0)
		{
			distCode = null;  // Empty distance code; the block shall be all literal symbols
		}
		else
		{
			if (oneCount == 1 && otherPositiveCount == 0)
			{
				// Only one distance code is used: it is encoded with one bit, leaving
				// the other one-bit code unused. Give that unused code to a dummy symbol
				// so the tree is complete; symbols 30 and 31 are both rejected by
				// decodeDistance(), so reaching the dummy is reported as a format error.
				int dummySymbol = (onlyOneSymbol == 31) ? 30 : 31;
				distCodeLen = Arrays.copyOf(distCodeLen, 32);
				distCodeLen[dummySymbol] = 1;
			}
			distCode = new CanonicalCode(distCodeLen).toCodeTree();
		}

		return new CodeTree[]{litLenCode, distCode};
	}

	/* Block decompression methods */

	// Handles a stored block (btype = 0): copies its bytes to the output verbatim.
	private void decompressUncompressedBlock() throws IOException 
	{
		// Discard bits to align to byte boundary
		while (input.getBitPosition() != 0)
		{
			input.readNoEof();
		}

		// Read length and its one's complement, which must agree
		int len  = readInt(16);
		int nlen = readInt(16);
		if ((len ^ 0xFFFF) != nlen)
		{
			throw new FormatException("Invalid length in uncompressed block");
		}

		// Copy bytes
		for (int i = 0; i < len; i++) 
		{
			int temp = input.readByte();
			if (temp == -1)
			{
				throw new EOFException();
			}
			output.write(temp);
			dictionary.append(temp);
		}
	}

	// Handles a Huffman-coded block (btype = 1 or 2): decodes literals and
	// length/distance pairs until the end-of-block symbol. 'distCode' may be
	// null, in which case any length symbol is a format error.
	private void decompressHuffmanBlock(CodeTree litLenCode, CodeTree distCode) throws IOException 
	{
		if (litLenCode == null)
		{
			throw new NullPointerException();
		}

		while (true) 
		{
			int sym = decodeSymbol(litLenCode);
			if (sym == 256)  // End of block
			{
				break;
			}

			if (sym < 256) 
			{  
				// Literal byte
				output.write(sym);
				dictionary.append(sym);
			} 
			else 
			{  // Length and distance for copying
				int len = decodeRunLength(sym);
				if (distCode == null)
				{
					throw new FormatException("Length symbol encountered with empty distance code");
				}
				int distSym = decodeSymbol(distCode);
				int dist = decodeDistance(distSym);
				// A back-reference may only point at bytes that were actually produced
				if (dist > output.size())
				{
					throw new FormatException("Distance refers to data before the start of the output");
				}
				dictionary.copy(dist, len, output);
			}
		}
	}

	/* Symbol decoding methods */

	// Reads bits, walking down the tree from its root, until a leaf is reached; returns that leaf's symbol.
	private int decodeSymbol(CodeTree code) throws IOException 
	{
		InternalNode currentNode = code.root;
		while (true) 
		{
			int temp = input.readNoEof();
			Node nextNode;
			if (temp == 0)
			{			
				nextNode = currentNode.leftChild;
			}
			else if (temp == 1)
			{
				nextNode = currentNode.rightChild;
			}
			else
			{
				throw new AssertionError();
			}

			if (nextNode instanceof Leaf)
			{
				return ((Leaf)nextNode).symbol;
			}
			else if (nextNode instanceof InternalNode)
			{
				currentNode = (InternalNode)nextNode;
			}
			else
			{
				throw new AssertionError();
			}
		}
	}

	// Converts a length symbol (257 to 285) into a copy length,
	// reading any extra bits that the symbol calls for.
	private int decodeRunLength(int sym) throws IOException 
	{
		if (sym < 257 || sym > 285)
		{
			throw new FormatException("Invalid run length symbol: " + sym);
		}
		else if (sym <= 264)
		{
			return sym - 254;
		}
		else if (sym <= 284) 
		{
			int i = (sym - 261) / 4;  // Number of extra bits to read
			return (((sym - 265) % 4 + 4) << i) + 3 + readInt(i);
		} 
		else  // sym == 285
		{
			return 258;
		}
	}

	// Converts a distance symbol (0 to 29) into a copy distance,
	// reading any extra bits that the symbol calls for.
	private int decodeDistance(int sym) throws IOException 
	{
		if (sym <= 3)
		{
			return sym + 1;
		}
		else if (sym <= 29) 
		{
			int i = sym / 2 - 1;  // Number of extra bits to read
			return ((sym % 2 + 2) << i) + 1 + readInt(i);
		} 
		else
		{
			throw new FormatException("Invalid distance symbol: " + sym);
		}
	}

	/* Utility method */

	// Reads 'numBits' bits (0 to 31) and assembles them into an integer, first bit read being the least significant.
	private int readInt(int numBits) throws IOException 
	{
		if (numBits < 0 || numBits >= 32)
		{
			throw new IllegalArgumentException();
		}

		int result = 0;
		for (int i = 0; i < numBits; i++)
		{
			result |= input.readNoEof() << i;
		}
		return result;
	}
}

// Unchecked exception signalling that compressed data violates its format.
@SuppressWarnings("serial")
class FormatException extends RuntimeException
{
	// Creates an exception without a detail message.
	public FormatException()
	{
	}

	// Creates an exception carrying the given detail message.
	public FormatException(String msg)
	{
		super(msg);
	}
}

// Small integer helper functions.
class IntegerMath
{
	/**
	 * Tests whether a value is a positive power of two (1, 2, 4, 8, ...).
	 *
	 * @param valueToCheck the value to test
	 * @return true if exactly one bit of a positive value is set; false for
	 *         zero, negative values and every other number
	 */
	public static boolean isPowerOf2(int valueToCheck)
	{
		// A power of two has a single set bit, and subtracting 1 clears it while
		// setting only lower bits, so the AND is zero. Integer arithmetic keeps
		// the test exact, unlike a comparison of floating-point logarithms.
		return valueToCheck > 0 && (valueToCheck & (valueToCheck - 1)) == 0;
	}
}
This entry was posted in Uncategorized and tagged , , , . Bookmark the permalink.

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out / Change )

Twitter picture

You are commenting using your Twitter account. Log Out / Change )

Facebook photo

You are commenting using your Facebook account. Log Out / Change )

Google+ photo

You are commenting using your Google+ account. Log Out / Change )

Connecting to %s