diff --git a/scripts/install-samtools.sh b/scripts/install-samtools.sh index 84feb30700..97238f6d2f 100755 --- a/scripts/install-samtools.sh +++ b/scripts/install-samtools.sh @@ -1,5 +1,6 @@ #!/bin/sh set -ex wget https://github.com/samtools/samtools/releases/download/1.14/samtools-1.14.tar.bz2 +# Note that the CRAM Interop Tests are dependent on the test files in samtools-1.14/htslib-1.14/htscodecs/tests/dat tar -xjvf samtools-1.14.tar.bz2 -cd samtools-1.14 && ./configure --prefix=/usr && make && sudo make install +cd samtools-1.14 && ./configure --prefix=/usr && make && sudo make install \ No newline at end of file
diff --git a/src/main/java/htsjdk/samtools/cram/compression/CompressionUtils.java b/src/main/java/htsjdk/samtools/cram/compression/CompressionUtils.java
new file mode 100644
index 0000000000..d4d1408448
--- /dev/null
+++ b/src/main/java/htsjdk/samtools/cram/compression/CompressionUtils.java
@@ -0,0 +1,204 @@
+package htsjdk.samtools.cram.compression;
+
+import htsjdk.samtools.cram.CRAMException;
+import htsjdk.samtools.cram.compression.rans.Constants;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+/**
+ * Static helpers shared by the CRAM codecs: 7-bit variable-length integer I/O, bit-packing of
+ * small alphabets, and factories for little-endian {@link ByteBuffer}s.
+ */
+public class CompressionUtils {
+
+    // the pack transform uses at most 4 bits per value, i.e. at most 16 distinct symbols
+    private static final int MAX_PACK_SYMBOLS = 16;
+
+    /**
+     * Write {@code i} as a variable-length integer: 7 bits per byte, most significant group first,
+     * with the top bit set on every byte except the last. The value is treated as an unsigned
+     * 32-bit quantity, so between 1 and 5 bytes are written.
+     *
+     * @param i value to write
+     * @param cp destination buffer; its position advances by the number of bytes written
+     */
+    public static void writeUint7(final int i, final ByteBuffer cp) {
+        // s ends up as the bit offset just past the most significant non-empty 7-bit group
+        int s = 0;
+        int X = i;
+        do {
+            s += 7;
+            // unsigned shift, so that values with the top bit set are not truncated to one byte
+            X >>>= 7;
+        } while (X != 0);
+        do {
+            s -= 7;
+            final int continuationBit = (s > 0) ? 0x80 : 0;
+            cp.put((byte) (((i >>> s) & 0x7f) | continuationBit));
+        } while (s > 0);
+    }
+
+    /**
+     * Read a variable-length integer in the format written by {@link #writeUint7(int, ByteBuffer)}.
+     *
+     * @param cp source buffer; its position advances past the bytes consumed
+     * @return the decoded value
+     */
+    public static int readUint7(final ByteBuffer cp) {
+        int i = 0;
+        int c;
+        do {
+            c = cp.get();
+            i = (i << 7) | (c & 0x7f);
+        } while ((c & 0x80) != 0); // top bit set => another group follows
+        return i;
+    }
+
+    /**
+     * Bit-pack {@code inBuffer} using 1, 2 or 4 bits per value (for up to 2, 4 or 16 symbols), lowest
+     * bits first, and write the pack header to {@code outBuffer}: the symbol count, each symbol with a
+     * non-zero frequency, and the packed length as a uint7. For {@code numSymbols <= 1} the packed
+     * data is empty.
+     * <p>
+     * NOTE(review): input bytes are read by absolute index starting at 0 while the length comes from
+     * {@code inBuffer.remaining()}, so this presumably expects {@code inBuffer.position() == 0}.
+     *
+     * @param inBuffer data to pack; its position is not changed
+     * @param outBuffer receives the pack header at its current position
+     * @param frequencyTable per-symbol counts; symbols with a count above zero are written to the header
+     * @param packMappingTable maps each input symbol to its packed value
+     * @param numSymbols number of distinct symbols in the input
+     * @return a new buffer holding the packed data, with position 0
+     * @throws CRAMException if {@code numSymbols} is greater than 16
+     */
+    public static ByteBuffer encodePack(
+            final ByteBuffer inBuffer,
+            final ByteBuffer outBuffer,
+            final int[] frequencyTable,
+            final int[] packMappingTable,
+            final int numSymbols){
+        if (numSymbols > MAX_PACK_SYMBOLS) {
+            // a mapped value >= 16 does not fit in 4 bits and would corrupt the neighbouring value
+            throw new CRAMException("Cannot bit-pack data with " + numSymbols + " symbols (maximum is 16)");
+        }
+        final int inSize = inBuffer.remaining();
+        final ByteBuffer encodedBuffer;
+        if (numSymbols <= 1) {
+            encodedBuffer = CompressionUtils.allocateByteBuffer(0);
+        } else {
+            final int bitsPerValue = bitsPerPackedValue(numSymbols);
+            final int valuesPerByte = 8 / bitsPerValue;
+            // integer ceiling of inSize / valuesPerByte; a freshly allocated buffer is zero-filled
+            encodedBuffer = CompressionUtils.allocateByteBuffer((inSize + valuesPerByte - 1) / valuesPerByte);
+            for (int i = 0; i < inSize; i++) {
+                final int j = i / valuesPerByte;
+                final int shift = (i % valuesPerByte) * bitsPerValue;
+                final int packedValue = packMappingTable[inBuffer.get(i) & 0xFF] << shift;
+                encodedBuffer.put(j, (byte) (encodedBuffer.get(j) | packedValue));
+            }
+        }
+
+        // write numSymbols
+        outBuffer.put((byte) numSymbols);
+
+        // write mapping table "packMappingTable" that converts mapped value to original symbol
+        for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) {
+            if (frequencyTable[i] > 0) {
+                outBuffer.put((byte) i);
+            }
+        }
+
+        // write the length of data
+        CompressionUtils.writeUint7(encodedBuffer.limit(), outBuffer);
+        return encodedBuffer; // Here position = 0 since we have always accessed the data buffer using index
+    }
+
+    /**
+     * Reverse of {@code encodePack}: expand bit-packed values back into symbols.
+     * <p>
+     * NOTE(review): packed bytes are read by absolute index starting at 0, regardless of
+     * {@code inBuffer.position()}.
+     *
+     * @param inBuffer packed data; its position is not changed
+     * @param packMappingTable maps a packed value to the original symbol
+     * @param numSymbols number of distinct symbols; selects 1, 2 or 4 bits per value
+     * @param uncompressedPackOutputLength number of symbols to produce
+     * @return a new buffer of {@code uncompressedPackOutputLength} bytes, with position 0
+     * @throws CRAMException if {@code numSymbols} is greater than 16
+     */
+    public static ByteBuffer decodePack(
+            final ByteBuffer inBuffer,
+            final byte[] packMappingTable,
+            final int numSymbols,
+            final int uncompressedPackOutputLength) {
+        if (numSymbols > MAX_PACK_SYMBOLS) {
+            // without this check an unsupported count would produce an all-zero buffer instead of an error
+            throw new CRAMException("Cannot unpack data with " + numSymbols + " symbols (maximum is 16)");
+        }
+        final ByteBuffer outBufferPack = CompressionUtils.allocateByteBuffer(uncompressedPackOutputLength);
+        if (numSymbols <= 1) {
+            // a single symbol carries no packed data; every output byte is that symbol
+            for (int i = 0; i < uncompressedPackOutputLength; i++) {
+                outBufferPack.put(i, packMappingTable[0]);
+            }
+            return outBufferPack;
+        }
+        final int bitsPerValue = bitsPerPackedValue(numSymbols);
+        final int valuesPerByte = 8 / bitsPerValue;
+        final int mask = (1 << bitsPerValue) - 1;
+        int j = 0;
+        int v = 0;
+        for (int i = 0; i < uncompressedPackOutputLength; i++) {
+            if (i % valuesPerByte == 0) {
+                v = inBuffer.get(j++);
+            }
+            outBufferPack.put(i, packMappingTable[v & mask]);
+            v >>= bitsPerValue;
+        }
+        return outBufferPack;
+    }
+
+    // number of bits per packed value for an alphabet of 2 to 16 symbols
+    private static int bitsPerPackedValue(final int numSymbols) {
+        if (numSymbols <= 2) {
+            return 1;
+        }
+        return (numSymbols <= 4) ? 2 : 4;
+    }
+
+    /**
+     * Allocate a little-endian output buffer for the rANS coders, sized with the samtools formula.
+     *
+     * @param inSize size of the uncompressed input, in bytes
+     * @return a new buffer of {@code inSize + 257 * 257 * 3 + 9} bytes
+     * @throws CRAMException if that size does not fit in an {@code int}
+     */
+    public static ByteBuffer allocateOutputBuffer(final int inSize) {
+        // This calculation is identical to the one in samtools rANS_static.c
+        // Presumably the frequency table (always big enough for order 1) = 257*257,
+        // then * 3 for each entry (byte->symbol, 2 bytes -> scaled frequency),
+        // + 9 for the header (order byte, and 2 int lengths for compressed/uncompressed lengths).
+        // Computed as a long so that a very large inSize cannot wrap around to a negative size.
+        final long compressedSize = (long) inSize + 257 * 257 * 3 + 9;
+        if (compressedSize > Integer.MAX_VALUE) {
+            throw new CRAMException("Failed to allocate sufficient buffer size for RANS coder.");
+        }
+        return allocateByteBuffer((int) compressedSize);
+    }
+
+    // returns a new LITTLE_ENDIAN ByteBuffer of size = bufferSize
+    public static ByteBuffer allocateByteBuffer(final int bufferSize){
+        return ByteBuffer.allocate(bufferSize).order(ByteOrder.LITTLE_ENDIAN);
+    }
+
+    // returns a LITTLE_ENDIAN ByteBuffer that is created by wrapping a byte[]
+    public static ByteBuffer wrap(final byte[] inputBytes){
+        return ByteBuffer.wrap(inputBytes).order(ByteOrder.LITTLE_ENDIAN);
+    }
+
+    // returns a LITTLE_ENDIAN ByteBuffer that is created by inputBuffer.slice()
+    public static ByteBuffer slice(final ByteBuffer inputBuffer){
+        return inputBuffer.slice().order(ByteOrder.LITTLE_ENDIAN);
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/htsjdk/samtools/cram/compression/ExternalCompressor.java b/src/main/java/htsjdk/samtools/cram/compression/ExternalCompressor.java index 4bc70ff46d..5c8f6b34fd 100644 --- a/src/main/java/htsjdk/samtools/cram/compression/ExternalCompressor.java +++ b/src/main/java/htsjdk/samtools/cram/compression/ExternalCompressor.java @@ -1,6 +1,9 @@ package htsjdk.samtools.cram.compression; -import htsjdk.samtools.cram.compression.rans.RANS; +import htsjdk.samtools.cram.compression.range.RangeDecode; +import htsjdk.samtools.cram.compression.range.RangeEncode; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Decode; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Encode; import htsjdk.samtools.cram.structure.block.BlockCompressionMethod; import htsjdk.utils.ValidationUtils; @@ -71,8 +74,13 @@ public static ExternalCompressor getCompressorForMethod( case RANS: return compressorSpecificArg == NO_COMPRESSION_ARG ?
- new RANSExternalCompressor(new RANS()) : - new RANSExternalCompressor(compressorSpecificArg, new RANS()); + new RANSExternalCompressor(new RANS4x8Encode(), new RANS4x8Decode()) : + new RANSExternalCompressor(compressorSpecificArg, new RANS4x8Encode(), new RANS4x8Decode()); + + case RANGE: + return compressorSpecificArg == NO_COMPRESSION_ARG ? + new RangeExternalCompressor(new RangeEncode(), new RangeDecode()) : + new RangeExternalCompressor(compressorSpecificArg, new RangeEncode(), new RangeDecode()); case BZIP2: ValidationUtils.validateArg( @@ -85,5 +93,4 @@ public static ExternalCompressor getCompressorForMethod( } } -} - +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/RANSExternalCompressor.java b/src/main/java/htsjdk/samtools/cram/compression/RANSExternalCompressor.java index 24a3f99c7f..dd4794b0e3 100644 --- a/src/main/java/htsjdk/samtools/cram/compression/RANSExternalCompressor.java +++ b/src/main/java/htsjdk/samtools/cram/compression/RANSExternalCompressor.java @@ -24,48 +24,60 @@ */ package htsjdk.samtools.cram.compression; -import htsjdk.samtools.cram.compression.rans.RANS; +import htsjdk.samtools.cram.compression.rans.RANSParams; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Decode; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Encode; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Params; import htsjdk.samtools.cram.structure.block.BlockCompressionMethod; import java.nio.ByteBuffer; import java.util.Objects; public final class RANSExternalCompressor extends ExternalCompressor { - private final RANS.ORDER order; - private final RANS rans; + private final RANSParams.ORDER order; + private final RANS4x8Encode ransEncode; + private final RANS4x8Decode ransDecode; /** * We use a shared RANS instance for all compressors. 
* @param rans */ - public RANSExternalCompressor(final RANS rans) { - this(RANS.ORDER.ZERO, rans); + public RANSExternalCompressor( + final RANS4x8Encode ransEncode, + final RANS4x8Decode ransDecode) { + this(RANSParams.ORDER.ZERO, ransEncode, ransDecode); } - public RANSExternalCompressor(final int order, final RANS rans) { - this(RANS.ORDER.fromInt(order), rans); + public RANSExternalCompressor( + final int order, + final RANS4x8Encode ransEncode, + final RANS4x8Decode ransDecode) { + this(RANSParams.ORDER.fromInt(order), ransEncode, ransDecode); } - public RANSExternalCompressor(final RANS.ORDER order, final RANS rans) { + public RANSExternalCompressor( + final RANSParams.ORDER order, + final RANS4x8Encode ransEncode, + final RANS4x8Decode ransDecode) { super(BlockCompressionMethod.RANS); - this.rans = rans; + this.ransEncode = ransEncode; + this.ransDecode = ransDecode; this.order = order; } @Override public byte[] compress(final byte[] data) { - final ByteBuffer buffer = rans.compress(ByteBuffer.wrap(data), order); + final RANS4x8Params params = new RANS4x8Params(order); + final ByteBuffer buffer = ransEncode.compress(CompressionUtils.wrap(data), params); return toByteArray(buffer); } @Override public byte[] uncompress(byte[] data) { - final ByteBuffer buf = rans.uncompress(ByteBuffer.wrap(data)); + final ByteBuffer buf = ransDecode.uncompress(CompressionUtils.wrap(data)); return toByteArray(buf); } - public RANS.ORDER getOrder() { return order; } - @Override public String toString() { return String.format("%s(%s)", this.getMethod(), order); @@ -96,4 +108,4 @@ private byte[] toByteArray(final ByteBuffer buffer) { return bytes; } -} +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/RangeExternalCompressor.java b/src/main/java/htsjdk/samtools/cram/compression/RangeExternalCompressor.java new file mode 100644 index 0000000000..650ac7c275 --- /dev/null +++ 
b/src/main/java/htsjdk/samtools/cram/compression/RangeExternalCompressor.java
@@ -0,0 +1,76 @@
+package htsjdk.samtools.cram.compression;
+
+import htsjdk.samtools.cram.compression.range.RangeDecode;
+import htsjdk.samtools.cram.compression.range.RangeEncode;
+import htsjdk.samtools.cram.compression.range.RangeParams;
+import htsjdk.samtools.cram.structure.block.BlockCompressionMethod;
+
+import java.nio.ByteBuffer;
+
+/**
+ * An {@link ExternalCompressor} for {@link BlockCompressionMethod#RANGE} that delegates to the
+ * {@link RangeEncode} and {@link RangeDecode} instances it is given.
+ */
+public class RangeExternalCompressor extends ExternalCompressor{
+
+    private final RangeEncode rangeEncode;
+    private final RangeDecode rangeDecode;
+    private final int formatFlags;
+
+    /**
+     * Create a compressor that uses format flags {@code 0}.
+     *
+     * @param rangeEncode encoder used by {@link #compress(byte[])}
+     * @param rangeDecode decoder used by {@link #uncompress(byte[])}
+     */
+    public RangeExternalCompressor(
+            final RangeEncode rangeEncode,
+            final RangeDecode rangeDecode) {
+        this(0, rangeEncode, rangeDecode);
+    }
+
+    /**
+     * @param formatFlags value handed to {@link RangeParams} on every {@link #compress(byte[])} call
+     * @param rangeEncode encoder used by {@link #compress(byte[])}
+     * @param rangeDecode decoder used by {@link #uncompress(byte[])}
+     */
+    public RangeExternalCompressor(
+            final int formatFlags,
+            final RangeEncode rangeEncode,
+            final RangeDecode rangeDecode) {
+        super(BlockCompressionMethod.RANGE);
+        this.formatFlags = formatFlags;
+        this.rangeEncode = rangeEncode;
+        this.rangeDecode = rangeDecode;
+    }
+
+    @Override
+    public byte[] compress(byte[] data) {
+        final RangeParams rangeParams = new RangeParams(formatFlags);
+        return toByteArray(rangeEncode.compress(CompressionUtils.wrap(data), rangeParams));
+    }
+
+    @Override
+    public byte[] uncompress(byte[] data) {
+        return toByteArray(rangeDecode.uncompress(CompressionUtils.wrap(data)));
+    }
+
+    @Override
+    public String toString() {
+        final Object method = this.getMethod();
+        return String.format("%s(%s)", method, formatFlags);
+    }
+
+    // Hand back the backing array itself when it is exactly the buffer's content up to its limit;
+    // otherwise copy the remaining bytes into a new array.
+    private byte[] toByteArray(final ByteBuffer buffer) {
+        final boolean arrayIsExactContent =
+                buffer.hasArray() && buffer.arrayOffset() == 0 && buffer.array().length == buffer.limit();
+        if (!arrayIsExactContent) {
+            final byte[] copy = new byte[buffer.remaining()];
+            buffer.get(copy);
+            return copy;
+        }
+        return buffer.array();
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/NameTokenisationDecode.java
b/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/NameTokenisationDecode.java
new file mode 100644
index 0000000000..61d935aad1
--- /dev/null
+++ b/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/NameTokenisationDecode.java
@@ -0,0 +1,203 @@
+package htsjdk.samtools.cram.compression.nametokenisation;
+
+import htsjdk.samtools.cram.CRAMException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.StringJoiner;
+
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_TYPE;
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_STRING;
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_CHAR;
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_DIGITS0;
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_DZLEN;
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_DUP;
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_DIGITS;
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_DELTA;
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_DELTA0;
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_MATCH;
+import static htsjdk.samtools.cram.compression.nametokenisation.TokenStreams.TOKEN_END;
+
+/**
+ * Decoder for read names compressed with the CRAM name tokenisation codec.
+ */
+public class NameTokenisationDecode {
+
+    /**
+     * Decode the names in {@code inBuffer}, joined with {@code "\n"}.
+     *
+     * @see #uncompress(ByteBuffer, String)
+     */
+    public static String uncompress(final ByteBuffer inBuffer) {
+        return uncompress(inBuffer, "\n");
+    }
+
+    /**
+     * Decode the names in {@code inBuffer}.
+     *
+     * @param inBuffer compressed stream: uncompressed length (int), number of names (int), a
+     *                 "use arithmetic coder" byte, then the token streams. Its byte order is set to little-endian.
+     * @param separator string placed between consecutive names
+     * @return the names joined by {@code separator}; a trailing separator is appended when that makes
+     *         the result exactly as long as the uncompressed length recorded in the stream
+     * @throws CRAMException if the stream declares a negative number of names
+     */
+    public static String uncompress(
+            final ByteBuffer inBuffer,
+            final String separator) {
+        inBuffer.order(ByteOrder.LITTLE_ENDIAN);
+        // only used below to decide whether the original text ended with a separator
+        final int uncompressedLength = inBuffer.getInt();
+        final int numNames = inBuffer.getInt();
+        if (numNames < 0) {
+            throw new CRAMException("Invalid number of names in name tokenisation stream: " + numNames);
+        }
+        final int useArith = inBuffer.get() & 0xFF;
+        final TokenStreams tokenStreams = new TokenStreams(inBuffer, useArith, numNames);
+        final List<List<String>> tokensList = new ArrayList<>(numNames);
+        for(int i = 0; i < numNames; i++) {
+            tokensList.add(new ArrayList<>());
+        }
+        final StringJoiner decodedNamesJoiner = new StringJoiner(separator);
+        for (int i = 0; i < numNames; i++) {
+            decodedNamesJoiner.add(decodeSingleName(tokenStreams, tokensList, i));
+        }
+        final String uncompressedNames = decodedNamesJoiner.toString();
+        if (uncompressedLength == uncompressedNames.length() + separator.length()){
+            return uncompressedNames + separator;
+        }
+        return uncompressedNames;
+    }
+
+    // Decode name number currentNameIndex, recording its tokens in tokensList so that later names
+    // can refer back to them.
+    private static String decodeSingleName(
+            final TokenStreams tokenStreams,
+            final List<List<String>> tokensList,
+            final int currentNameIndex) {
+
+        // The information about whether a name is a duplicate or not
+        // is obtained from the list of tokens at tokenStreams[0,0]
+        final byte nameType = tokenStreams.getTokenStreamByteBuffer(0,TOKEN_TYPE).get();
+        final ByteBuffer distBuffer = tokenStreams.getTokenStreamByteBuffer(0,nameType).order(ByteOrder.LITTLE_ENDIAN);
+        final int dist = distBuffer.getInt();
+        final int prevNameIndex = currentNameIndex - dist;
+        if (nameType == TOKEN_DUP){
+            if (prevNameIndex < 0 || prevNameIndex >= currentNameIndex) {
+                throw new CRAMException(String.format(
+                        "Name %d is marked as a duplicate of a name at invalid distance %d", currentNameIndex, dist));
+            }
+            // replace (not insert) the placeholder so tokensList keeps exactly one entry per name
+            tokensList.set(currentNameIndex, tokensList.get(prevNameIndex));
+            return String.join("", tokensList.get(currentNameIndex));
+        }
+        int tokenPosition = 1; // At position 0, we get nameType information
+        byte type;
+        final StringBuilder decodedNameBuilder = new StringBuilder();
+        do {
+            type = tokenStreams.getTokenStreamByteBuffer(tokenPosition, TOKEN_TYPE).get();
+            String currentToken = "";
+            switch(type){
+                case TOKEN_CHAR:
+                    // mask so that bytes >= 0x80 are not sign-extended into unrelated characters
+                    final char currentTokenChar = (char) (tokenStreams.getTokenStreamByteBuffer(tokenPosition, TOKEN_CHAR).get() & 0xFF);
+                    currentToken = String.valueOf(currentTokenChar);
+                    break;
+                case TOKEN_STRING:
+                    currentToken = readString(tokenStreams.getTokenStreamByteBuffer(tokenPosition, TOKEN_STRING));
+                    break;
+                case TOKEN_DIGITS:
+                    currentToken = getDigitsToken(tokenStreams, tokenPosition, TOKEN_DIGITS);
+                    break;
+                case TOKEN_DIGITS0:
+                    final String digits0Token = getDigitsToken(tokenStreams, tokenPosition, TOKEN_DIGITS0);
+                    final int lenDigits0Token = tokenStreams.getTokenStreamByteBuffer(tokenPosition, TOKEN_DZLEN).get() & 0xFF;
+                    currentToken = leftPadNumber(digits0Token, lenDigits0Token);
+                    break;
+                case TOKEN_DELTA:
+                    currentToken = getDeltaToken(tokenStreams, tokenPosition, tokensList, prevNameIndex, TOKEN_DELTA);
+                    break;
+                case TOKEN_DELTA0:
+                    final String delta0Token = getDeltaToken(tokenStreams, tokenPosition, tokensList, prevNameIndex, TOKEN_DELTA0);
+                    final int lenDelta0Token = tokensList.get(prevNameIndex).get(tokenPosition-1).length();
+                    currentToken = leftPadNumber(delta0Token, lenDelta0Token);
+                    break;
+                case TOKEN_MATCH:
+                    currentToken = tokensList.get(prevNameIndex).get(tokenPosition-1);
+                    break;
+                default:
+                    // TOKEN_END and any other type contribute an empty token
+                    break;
+            }
+            tokensList.get(currentNameIndex).add(tokenPosition-1,currentToken);
+            decodedNameBuilder.append(currentToken);
+            tokenPosition++;
+        } while (type!= TOKEN_END);
+        return decodedNameBuilder.toString();
+    }
+
+    // Decode a DELTA/DELTA0 token: the matching token of the previous name plus a one-byte delta.
+    private static String getDeltaToken(
+            final TokenStreams tokenStreams,
+            final int tokenPosition,
+            final List<List<String>> tokensList,
+            final int prevNameIndex,
+            final byte tokenType) {
+        if (!(tokenType == TOKEN_DELTA || tokenType == TOKEN_DELTA0)){
+            throw new CRAMException(String.format("Invalid tokenType : %s. " +
+                    "tokenType must be either TOKEN_DELTA or TOKEN_DELTA0", tokenType));
+        }
+        // parsed and summed as a long: the previous token may be any unsigned 32-bit DIGITS value
+        final long prevToken;
+        try {
+            prevToken = Long.parseLong(tokensList.get(prevNameIndex).get(tokenPosition -1));
+        } catch (final NumberFormatException e) {
+            final String exceptionMessageSubstring = (tokenType == TOKEN_DELTA) ? "DIGITS or DELTA" : "DIGITS0 or DELTA0";
+            throw new CRAMException(String.format("The token in the prior name must be of type %s",
+                    exceptionMessageSubstring), e);
+        }
+        final int deltaTokenValue = tokenStreams.getTokenStreamByteBuffer(tokenPosition,tokenType).get() & 0xFF;
+        return Long.toString(prevToken + deltaTokenValue);
+    }
+
+    // Decode a DIGITS/DIGITS0 token: a little-endian unsigned 32-bit integer rendered in decimal.
+    private static String getDigitsToken(
+            final TokenStreams tokenStreams,
+            final int tokenPosition,
+            final byte tokenType ) {
+        if (!(tokenType == TOKEN_DIGITS || tokenType == TOKEN_DIGITS0)){
+            throw new CRAMException(String.format("Invalid tokenType : %s. " +
+                    "tokenType must be either TOKEN_DIGITS or TOKEN_DIGITS0", tokenType));
+        }
+        final ByteBuffer digitsByteBuffer = tokenStreams.getTokenStreamByteBuffer(tokenPosition, tokenType).order(ByteOrder.LITTLE_ENDIAN);
+        final long digits = digitsByteBuffer.getInt() & 0xFFFFFFFFL;
+        return Long.toString(digits);
+    }
+
+    private static String readString(final ByteBuffer inputBuffer) {
+        // spec: We fetch one byte at a time from the value byte stream,
+        // appending to the name buffer until the byte retrieved is zero.
+        final StringBuilder resultStringBuilder = new StringBuilder();
+        byte currentByte = inputBuffer.get();
+        while (currentByte != 0) {
+            // mask so that bytes >= 0x80 are not sign-extended into unrelated characters
+            resultStringBuilder.append((char) (currentByte & 0xFF));
+            currentByte = inputBuffer.get();
+        }
+        return resultStringBuilder.toString();
+    }
+
+    private static String leftPadNumber(final String value, final int len) {
+        // return value such that it is at least len bytes long with leading zeros
+        if (value.length() >= len) {
+            return value;
+        }
+        final StringBuilder padded = new StringBuilder(len);
+        for (int i = value.length(); i < len; i++) {
+            padded.append('0');
+        }
+        return padded.append(value).toString();
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/NameTokenisationEncode.java b/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/NameTokenisationEncode.java
new file mode 100644
index 0000000000..4a07f8422f
--- /dev/null
+++ b/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/NameTokenisationEncode.java
@@ -0,0 +1,365 @@
+package htsjdk.samtools.cram.compression.nametokenisation;
+
+import htsjdk.samtools.cram.compression.CompressionUtils;
+import htsjdk.samtools.cram.compression.nametokenisation.tokens.EncodeToken;
+import htsjdk.samtools.cram.compression.range.RangeEncode;
+import htsjdk.samtools.cram.compression.range.RangeParams;
+import htsjdk.samtools.cram.compression.rans.RANSEncode;
+import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Encode;
+import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Params;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * Encoder for the CRAM name tokenisation codec: splits each read name into tokens, expresses them
+ * relative to the previous name where possible, and compresses the resulting byte streams.
+ * <p>
+ * Not thread-safe: {@link #compress(ByteBuffer, int)} keeps per-call state in instance fields.
+ */
+public class NameTokenisationEncode {
+
+    // a token is either a run of up to 9 alphanumerics or a run of non-alphanumerics
+    private static final Pattern TOKEN_PATTERN = Pattern.compile("([a-zA-Z0-9]{1,9})|([^a-zA-Z0-9]+)");
+    // all digits, with at least one leading zero
+    private static final Pattern DIGITS0_PATTERN = Pattern.compile("^0+[0-9]*$");
+    // all digits
+    private static final Pattern DIGITS_PATTERN = Pattern.compile("^[0-9]+$");
+
+    private int maxToken;
+    private int maxLength;
+
+    /**
+     * Compress newline-separated names, passing {@code 0} for {@code useArith}.
+     *
+     * @see #compress(ByteBuffer, int)
+     */
+    public ByteBuffer compress(final ByteBuffer inBuffer){
+        return compress(inBuffer, 0);
+    }
+
+    /**
+     * Compress newline-separated names.
+     *
+     * @param inBuffer the names, each terminated by {@code '\n'} (the terminator of the last name is optional);
+     *                 the buffer is consumed up to its limit
+     * @param useArith non-zero to compress the token streams with the range coder, zero for rANS Nx16
+     * @return the compressed stream, little-endian, positioned at 0 with its limit at the end of the data
+     */
+    public ByteBuffer compress(final ByteBuffer inBuffer, final int useArith){
+        maxToken = 0;
+        maxLength = 0;
+        final ArrayList<String> names = new ArrayList<>();
+        int lastPosition = inBuffer.position();
+
+        // convert buffer to array of names
+        while(inBuffer.hasRemaining()){
+            byte currentByte = inBuffer.get();
+            if ((currentByte) == '\n' || inBuffer.position()==inBuffer.limit()){
+                int length = inBuffer.position() - lastPosition;
+                byte[] bytes = new byte[length];
+                inBuffer.position(lastPosition);
+                inBuffer.get(bytes, 0, length);
+                names.add(new String(bytes, StandardCharsets.UTF_8).trim());
+                lastPosition = inBuffer.position();
+            }
+        }
+
+        final int numNames = names.size();
+        // guess max size -> str.length*2 + 10000 (from htscodecs javascript code)
+        final ByteBuffer outBuffer = allocateOutputBuffer((inBuffer.limit()*2)+10000);
+        outBuffer.putInt(inBuffer.limit());
+        outBuffer.putInt(numNames);
+        outBuffer.put((byte)useArith);
+
+        // Instead of List<List<String>> for tokensList like we did in Decoder, we use List<List<EncodeToken>>
+        // as we also need to store the TOKEN_TYPE, relative value when compared to prev name's token
+        // along with the token value.
+        final List<List<EncodeToken>> tokensList = new ArrayList<>(numNames);
+        final HashMap<String, Integer> nameIndexMap = new HashMap<>();
+        final int[] tokenFrequencies = new int[256];
+        for(int nameIndex = 0; nameIndex < numNames; nameIndex++) {
+            tokeniseName(tokensList, nameIndexMap, tokenFrequencies, names.get(nameIndex), nameIndex);
+        }
+        for (int tokenPosition = 0; tokenPosition < maxToken; tokenPosition++) {
+            final List<ByteBuffer> tokenStream = new ArrayList<>(TokenStreams.TOTAL_TOKEN_TYPES);
+            for (int i = 0; i < TokenStreams.TOTAL_TOKEN_TYPES; i++) {
+                tokenStream.add(ByteBuffer.allocate(numNames* maxLength).order(ByteOrder.LITTLE_ENDIAN));
+            }
+            fillByteStreams(tokenStream,tokensList,tokenPosition,numNames);
+            // fillByteStreams leaves every buffer positioned just past the bytes it wrote. Flip them so that
+            // only the written bytes are serialised: without this, remaining() measures the unused capacity,
+            // so every stream (even an untouched one) would be emitted along with its zero padding.
+            for (final ByteBuffer stream : tokenStream) {
+                stream.flip();
+            }
+            serializeByteStreams(tokenStream,useArith,outBuffer);
+        }
+
+        // sets limit to current position and position to '0'
+        outBuffer.flip();
+        return outBuffer;
+    }
+
+    // Tokenise one name, appending its token list to tokensList and updating maxToken/maxLength.
+    private void tokeniseName(final List<List<EncodeToken>> tokensList,
+                              final HashMap<String, Integer> nameIndexMap,
+                              final int[] tokenFrequencies,
+                              final String name,
+                              final int currentNameIndex) {
+        // floor of 4: every name writes a 4-byte DIFF/DUP distance even when it has no other tokens
+        int currMaxLength = 4;
+
+        // always compare against last name only
+        final int prevNameIndex = currentNameIndex - 1;
+        final List<EncodeToken> currentTokens = new ArrayList<>();
+        tokensList.add(currentTokens);
+        final Integer duplicateOfIndex = nameIndexMap.get(name);
+        if (duplicateOfIndex != null) {
+            // TODO: Add Test to cover this code
+            final String dist = String.valueOf(currentNameIndex - duplicateOfIndex);
+            currentTokens.add(new EncodeToken(dist, dist, TokenStreams.TOKEN_DUP));
+        } else {
+            final String dist = String.valueOf(currentNameIndex == 0 ? 0 : 1);
+            currentTokens.add(new EncodeToken(dist, dist, TokenStreams.TOKEN_DIFF));
+        }
+        // Get the list of tokens `tok` for the current name
+        nameIndexMap.put(name, currentNameIndex);
+        final Matcher matcher = TOKEN_PATTERN.matcher(name);
+        final List<String> tok = new ArrayList<>();
+        while (matcher.find()) {
+            tok.add(matcher.group());
+        }
+        for (int i = 0; i < tok.size(); i++) {
+            // In the list of tokens, all the tokens are offset by 1
+            // because at position "0", we have a token that provides info if the name is a DIFF or DUP
+            // token 0 = DIFF vs DUP
+            final int tokenIndex = i + 1;
+            byte type = TokenStreams.TOKEN_STRING;
+            final String str = tok.get(i); // absolute value of the token
+            String val = tok.get(i); // relative value of the token (comparing to prevname's token at the same token position)
+            if (DIGITS0_PATTERN.matcher(str).matches()) {
+                type = TokenStreams.TOKEN_DIGITS0;
+            } else if (DIGITS_PATTERN.matcher(str).matches()) {
+                type = TokenStreams.TOKEN_DIGITS;
+            } else if (str.length() == 1) {
+                type = TokenStreams.TOKEN_CHAR;
+            }
+
+            // compare the current token with token from the previous name at the current token's index
+            // if there exists a previous name and a token at the corresponding index of the previous name
+            if (prevNameIndex >=0 && tokensList.get(prevNameIndex).size() > tokenIndex) {
+                final EncodeToken prevToken = tokensList.get(prevNameIndex).get(tokenIndex);
+                if (prevToken.getActualTokenValue().equals(str)) {
+                    type = TokenStreams.TOKEN_MATCH;
+                    val = "";
+                } else if (type==TokenStreams.TOKEN_DIGITS
+                        && (prevToken.getTokenType() == TokenStreams.TOKEN_DIGITS || prevToken.getTokenType() == TokenStreams.TOKEN_DELTA)) {
+                    int v = Integer.parseInt(val);
+                    int s = Integer.parseInt(prevToken.getActualTokenValue());
+                    int d = v - s;
+                    tokenFrequencies[tokenIndex]++;
+                    if (d >= 0 && d < 256 && tokenFrequencies[tokenIndex] > currentNameIndex / 2) {
+                        type = TokenStreams.TOKEN_DELTA;
+                        val = String.valueOf(d);
+                    }
+                } else if (type==TokenStreams.TOKEN_DIGITS0 && prevToken.getActualTokenValue().length() == val.length()
+                        && (prevToken.getTokenType() == TokenStreams.TOKEN_DIGITS0 || prevToken.getTokenType() == TokenStreams.TOKEN_DELTA0)) {
+                    int d = Integer.parseInt(val) - Integer.parseInt(prevToken.getActualTokenValue());
+                    tokenFrequencies[tokenIndex]++;
+                    if (d >= 0 && d < 256 && tokenFrequencies[tokenIndex] > currentNameIndex / 2) {
+                        type = TokenStreams.TOKEN_DELTA0;
+                        val = String.valueOf(d);
+                    }
+                }
+            }
+            currentTokens.add(new EncodeToken(str, val, type));
+
+            if (currMaxLength < val.length() + 3) {
+                // TODO: check this? Why isn't unint32 case handled?
+                // +3 for integers; 5 -> (Uint32)5 (from htscodecs javascript code)
+                currMaxLength = val.length() + 3;
+            }
+        }
+
+        currentTokens.add(new EncodeToken("","",TokenStreams.TOKEN_END));
+        final int currMaxToken = currentTokens.size();
+        if (maxToken < currMaxToken)
+            maxToken = currMaxToken;
+        if (maxLength < currMaxLength)
+            maxLength = currMaxLength;
+    }
+
+    /**
+     * Append the tokens found at {@code tokenPosition} of every name to the per-type byte streams.
+     *
+     * @param tokenStream one writable buffer per token type, indexed by token type; each buffer's position
+     *                    advances past the bytes written
+     * @param tokensList the tokens of every name
+     * @param tokenPosition index of the token within each name (0 is the DIFF/DUP marker)
+     * @param numNames number of names in {@code tokensList} to process
+     */
+    public void fillByteStreams(
+            final List<ByteBuffer> tokenStream,
+            final List<List<EncodeToken>> tokensList,
+            final int tokenPosition,
+            final int numNames) {
+
+        // Fill tokenStreams object using tokensList
+        for (int nameIndex = 0; nameIndex < numNames; nameIndex++) {
+            if (tokenPosition > 0 && tokensList.get(nameIndex).get(0).getTokenType() == TokenStreams.TOKEN_DUP) {
+                continue;
+            }
+            if (tokensList.get(nameIndex).size() <= tokenPosition) {
+                continue;
+            }
+            final EncodeToken encodeToken = tokensList.get(nameIndex).get(tokenPosition);
+            final byte type = encodeToken.getTokenType();
+            tokenStream.get(TokenStreams.TOKEN_TYPE).put(type);
+            switch (type) {
+                case TokenStreams.TOKEN_DIFF:
+                    tokenStream.get(TokenStreams.TOKEN_DIFF).putInt(Integer.parseInt(encodeToken.getRelativeTokenValue()));
+                    break;
+
+                case TokenStreams.TOKEN_DUP:
+                    tokenStream.get(TokenStreams.TOKEN_DUP).putInt(Integer.parseInt(encodeToken.getRelativeTokenValue()));
+                    break;
+
+                case TokenStreams.TOKEN_STRING:
+                    writeString(tokenStream.get(TokenStreams.TOKEN_STRING),encodeToken.getRelativeTokenValue());
+                    break;
+
+                case TokenStreams.TOKEN_CHAR:
+                    tokenStream.get(TokenStreams.TOKEN_CHAR).put(encodeToken.getRelativeTokenValue().getBytes(StandardCharsets.UTF_8)[0]);
+                    break;
+
+                case TokenStreams.TOKEN_DIGITS:
+                    tokenStream.get(TokenStreams.TOKEN_DIGITS).putInt(Integer.parseInt(encodeToken.getRelativeTokenValue()));
+                    break;
+
+                case TokenStreams.TOKEN_DIGITS0:
+                    tokenStream.get(TokenStreams.TOKEN_DIGITS0).putInt(Integer.parseInt(encodeToken.getRelativeTokenValue()));
+                    tokenStream.get(TokenStreams.TOKEN_DZLEN).put((byte) encodeToken.getRelativeTokenValue().length());
+                    break;
+
+                case TokenStreams.TOKEN_DELTA:
+                    tokenStream.get(TokenStreams.TOKEN_DELTA).put((byte)Integer.parseInt(encodeToken.getRelativeTokenValue()));
+                    break;
+
+                case TokenStreams.TOKEN_DELTA0:
+                    tokenStream.get(TokenStreams.TOKEN_DELTA0).put((byte)Integer.parseInt(encodeToken.getRelativeTokenValue()));
+                    break;
+
+                default:
+                    // every other type (e.g. TOKEN_MATCH, TOKEN_END) has no payload beyond its type byte
+                    break;
+            }
+        }
+    }
+
+    // Write val followed by a zero terminator.
+    private static void writeString(final ByteBuffer tokenStreamBuffer, final String val) {
+        // explicit charset: the names were decoded as UTF-8, and the platform default may differ
+        tokenStreamBuffer.put(val.getBytes(StandardCharsets.UTF_8));
+        tokenStreamBuffer.put((byte) 0);
+    }
+
+    /**
+     * Compress {@code src} with several candidate format-flag settings and keep the smallest result.
+     * Candidates that throw are skipped.
+     *
+     * @param src data to compress; it is rewound before every attempt
+     * @param useArith non-zero for the range coder, zero for rANS Nx16
+     * @return the smallest compressed buffer, or {@code null} if no candidate succeeded
+     */
+    public static ByteBuffer tryCompress(final ByteBuffer src, final int useArith) {
+        // compress with different formatFlags
+        // and return the compressed output ByteBuffer with the least number of bytes
+        int bestcompressedByteLength = 1 << 30;
+        ByteBuffer compressedByteBuffer = null;
+        // Every attempt rewinds src, so the size that matters is the rewound size. Measure it once:
+        // re-reading remaining() between attempts may see a buffer the previous encoder has consumed.
+        src.rewind();
+        final int srcSize = src.remaining();
+        final int[] formatFlagsList = {0, 1, 64, 65, 128, 129, 193+8};
+        for (final int formatFlags : formatFlagsList) {
+            if ((formatFlags & 1) != 0 && srcSize < 100)
+                continue;
+
+            if ((formatFlags & 8) != 0 && (srcSize % 4) != 0)
+                continue;
+
+            ByteBuffer tmpByteBuffer = null;
+            try {
+                if (useArith!=0) {
+                    // Encode using Range
+                    RangeEncode rangeEncode = new RangeEncode();
+                    src.rewind();
+                    tmpByteBuffer = rangeEncode.compress(src,new RangeParams(formatFlags));
+
+                } else {
+                    // Encode using RANS
+                    RANSEncode ransEncode = new RANSNx16Encode();
+                    src.rewind();
+                    tmpByteBuffer = ransEncode.compress(src, new RANSNx16Params(formatFlags));
+                }
+            } catch (final Exception ignored) {
+                // deliberate best effort: a flag combination that fails is simply not a candidate
+            }
+            if (tmpByteBuffer != null && bestcompressedByteLength > tmpByteBuffer.remaining()) {
+                bestcompressedByteLength = tmpByteBuffer.remaining();
+                compressedByteBuffer = tmpByteBuffer;
+            }
+        }
+        return compressedByteBuffer;
+    }
+
+    /**
+     * Compress every non-empty stream in {@code tokenStream} and append it to {@code outBuffer} as:
+     * a type byte (with the top bit set for type 0), the compressed length as a uint7, then the data.
+     *
+     * @param tokenStream one buffer per token type, each readable from its position to its limit
+     * @param useArith non-zero for the range coder, zero for rANS Nx16
+     * @param outBuffer destination buffer
+     * @throws RuntimeException if a stream cannot be compressed with any candidate setting
+     */
+    protected void serializeByteStreams(
+            final List<ByteBuffer> tokenStream,
+            final int useArith,
+            final ByteBuffer outBuffer) {
+
+        // Compress and serialise tokenStreams
+        for (int tokenType = 0; tokenType <= TokenStreams.TOKEN_END; tokenType++) {
+            if (tokenStream.get(tokenType).remaining() > 0) {
+                final ByteBuffer tempOutByteBuffer = tryCompress(tokenStream.get(tokenType), useArith);
+                if (tempOutByteBuffer == null) {
+                    // fail with a clear message rather than a NullPointerException
+                    throw new RuntimeException("Failed to compress token stream of type " + tokenType);
+                }
+                outBuffer.put((byte) (tokenType + ((tokenType == 0) ? 128 : 0)));
+                // the length written must be the number of bytes that put() copies, i.e. remaining()
+                CompressionUtils.writeUint7(tempOutByteBuffer.remaining(),outBuffer);
+                outBuffer.put(tempOutByteBuffer);
+            }
+        }
+    }
+
+    protected ByteBuffer allocateOutputBuffer(final int inSize) {
+
+        // same as the allocateOutputBuffer in RANS4x8Encode and RANSNx16Encode
+        // TODO: de-duplicate
+        final int compressedSize = (int) (1.05 * inSize + 257 * 257 * 3 + 9);
+        final ByteBuffer outputBuffer = ByteBuffer.allocate(compressedSize);
+        if (outputBuffer.remaining() < compressedSize) {
+            throw new RuntimeException("Failed to allocate sufficient buffer size for Range coder.");
+        }
+        outputBuffer.order(ByteOrder.LITTLE_ENDIAN);
+        return outputBuffer;
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/TokenStreams.java b/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/TokenStreams.java new file mode 100644 index 0000000000..deed459022 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/TokenStreams.java @@ -0,0 +1,125 @@ +package htsjdk.samtools.cram.compression.nametokenisation; + +import htsjdk.samtools.cram.CRAMException; +import htsjdk.samtools.cram.compression.CompressionUtils; +import htsjdk.samtools.cram.compression.range.RangeDecode; +import htsjdk.samtools.cram.compression.rans.RANSDecode; +import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Decode; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +public class TokenStreams { + + public static final byte TOKEN_TYPE = 0x00; + public static final byte TOKEN_STRING = 0x01; + public static final byte TOKEN_CHAR = 0x02; + public static final byte TOKEN_DIGITS0 = 0x03; + public static final byte TOKEN_DZLEN = 0x04; + public static final
byte TOKEN_DUP = 0x05; + public static final byte TOKEN_DIFF = 0x06; + public static final byte TOKEN_DIGITS = 0x07; + public static final byte TOKEN_DELTA = 0x08; + public static final byte TOKEN_DELTA0 = 0x09; + public static final byte TOKEN_MATCH = 0x0A; + public static final byte TOKEN_END = 0x0C; + public static final int TOTAL_TOKEN_TYPES = 13; + + private static final int NEW_TOKEN_FLAG_MASK = 0x80; + private static final int DUP_TOKEN_FLAG_MASK = 0x40; + private static final int TYPE_TOKEN_FLAG_MASK = 0x3F; + + private final List> tokenStreams; + + public TokenStreams() { + tokenStreams = new ArrayList<>(TOTAL_TOKEN_TYPES); + for (int i = 0; i < TOTAL_TOKEN_TYPES; i++) { + tokenStreams.add(new ArrayList<>()); + } + } + + public TokenStreams(final ByteBuffer inputByteBuffer, final int useArith, final int numNames) { + // The outer index corresponds to type of the token + // and the inner index corresponds to the position of the token in a name (starting at index 1) + // Each element in this list of lists is a Token (ie, a ByteBuffer) + + // TokenStreams[type = TOKEN_TYPE(0x00), pos = 0] contains a ByteBuffer of length = number of names + // This ByteBuffer helps determine if each of the names is a TOKEN_DUP or TOKEN_DIFF + // when compared with the previous name + + // TokenStreams[type = TOKEN_TYPE(0x00), pos = all except 0] + // contains a ByteBuffer of length = number of names + // This ByteBuffer helps determine the type of each of the token at the specicfied pos + + this(); + int tokenPosition = -1; + while (inputByteBuffer.hasRemaining()) { + final byte tokenTypeFlags = inputByteBuffer.get(); + final boolean isNewToken = ((tokenTypeFlags & NEW_TOKEN_FLAG_MASK) != 0); + final boolean isDupToken = ((tokenTypeFlags & DUP_TOKEN_FLAG_MASK) != 0); + final int tokenType = (tokenTypeFlags & TYPE_TOKEN_FLAG_MASK); + if (tokenType < 0 || tokenType > TOKEN_END) { + throw new CRAMException("Invalid Token tokenType: " + tokenType); + } + if (isNewToken) { + 
tokenPosition++; + if (tokenPosition > 0) { + // If newToken and not the first newToken + // Ensure that the size of tokenStream for each type of token = tokenPosition + // by adding an empty ByteBuffer if needed + for (int i = 0; i < TOTAL_TOKEN_TYPES; i++) { + final List currTokenStream = tokenStreams.get(i); + if (currTokenStream.size() < tokenPosition) { + currTokenStream.add(ByteBuffer.allocate(0)); + } + if (currTokenStream.size() < tokenPosition) { + throw new CRAMException("TokenStream is missing Token(s) at Token Type: " + i); + } + } + } + } + if ((isNewToken) && (tokenType != TOKEN_TYPE)) { + + // Spec: if we have a byte stream B5,DIGIT S but no B5,T Y P E + // then we assume the contents of B5,T Y P E consist of one DIGITS tokenType + // followed by as many MATCH types as are needed. + final ByteBuffer typeDataByteBuffer = ByteBuffer.allocate(numNames); + for (int i = 0; i < numNames; i++) { + typeDataByteBuffer.put((byte) TOKEN_MATCH); + } + typeDataByteBuffer.rewind(); + typeDataByteBuffer.put(0, (byte) tokenType); + tokenStreams.get(0).add(typeDataByteBuffer); + } + if (isDupToken) { + final int dupPosition = inputByteBuffer.get() & 0xFF; + final int dupType = inputByteBuffer.get() & 0xFF; + final ByteBuffer dupTokenStream = tokenStreams.get(dupType).get(dupPosition).duplicate(); + tokenStreams.get(tokenType).add(tokenPosition,dupTokenStream); + } else { + final int clen = CompressionUtils.readUint7(inputByteBuffer); + final byte[] dataBytes = new byte[clen]; + inputByteBuffer.get(dataBytes, 0, clen); // offset in the dst byte array + final ByteBuffer uncompressedDataByteBuffer; + if (useArith != 0) { + RangeDecode rangeDecode = new RangeDecode(); + uncompressedDataByteBuffer = rangeDecode.uncompress(ByteBuffer.wrap(dataBytes)); + + } else { + RANSDecode ransdecode = new RANSNx16Decode(); + uncompressedDataByteBuffer = ransdecode.uncompress(ByteBuffer.wrap(dataBytes)); + } + 
this.getTokenStreamByType(tokenType).add(tokenPosition,uncompressedDataByteBuffer); + } + } + } + + public List getTokenStreamByType(final int tokenType) { + return tokenStreams.get(tokenType); + } + + public ByteBuffer getTokenStreamByteBuffer(final int tokenPosition, final int tokenType) { + return tokenStreams.get(tokenType).get(tokenPosition); + } +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/tokens/EncodeToken.java b/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/tokens/EncodeToken.java new file mode 100644 index 0000000000..4e7cb0288a --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/nametokenisation/tokens/EncodeToken.java @@ -0,0 +1,38 @@ +package htsjdk.samtools.cram.compression.nametokenisation.tokens; + +public class EncodeToken { + + private String actualTokenValue; + private String relativeTokenValue; + private byte tokenType; + + public EncodeToken(String str, String val, byte type) { + this.actualTokenValue = str; + this.relativeTokenValue = val; + this.tokenType = type; + } + + public String getActualTokenValue() { + return actualTokenValue; + } + + public void setActualTokenValue(String actualTokenValue) { + this.actualTokenValue = actualTokenValue; + } + + public String getRelativeTokenValue() { + return relativeTokenValue; + } + + public void setRelativeTokenValue(String relativeTokenValue) { + this.relativeTokenValue = relativeTokenValue; + } + + public byte getTokenType() { + return tokenType; + } + + public void setTokenType(byte tokenType) { + this.tokenType = tokenType; + } +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/range/ByteModel.java b/src/main/java/htsjdk/samtools/cram/compression/range/ByteModel.java new file mode 100644 index 0000000000..f2f71c4e2a --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/range/ByteModel.java @@ -0,0 +1,107 @@ +package 
htsjdk.samtools.cram.compression.range; + +import java.nio.ByteBuffer; + +public class ByteModel { + // spec: To encode any symbol the entropy encoder needs to know + // the frequency of the symbol to encode, + // the cumulative frequencies of all symbols prior to this symbol, + // and the total of all frequencies. + public int totalFrequency; + public final int maxSymbol; + public final int[] symbols; + public final int[] frequencies; + + public ByteModel(final int numSymbols) { + // Spec: ModelCreate method + this.totalFrequency = numSymbols; + this.maxSymbol = numSymbols - 1; + frequencies = new int[maxSymbol+1]; + symbols = new int[maxSymbol+1]; + for (int i = 0; i <= maxSymbol; i++) { + this.symbols[i] = i; + this.frequencies[i] = 1; + } + } + + public int modelDecode(final ByteBuffer inBuffer, final RangeCoder rangeCoder){ + + // decodes one symbol + final int freq = rangeCoder.rangeGetFrequency(totalFrequency); + int cumulativeFrequency = 0; + int x = 0; + while (cumulativeFrequency + frequencies[x] <= freq){ + cumulativeFrequency += frequencies[x++]; + } + + // update rangecoder + rangeCoder.rangeDecode(inBuffer,cumulativeFrequency,frequencies[x]); + + // update model frequencies + frequencies[x] += Constants.STEP; + totalFrequency += Constants.STEP; + if (totalFrequency > Constants.MAX_FREQ){ + // if totalFrequency is too high, the frequencies are halved, making + // sure to avoid any zero frequencies being created. 
+ modelRenormalize(); + } + + // keep symbols approximately frequency sorted + final int symbol = symbols[x]; + if (x > 0 && frequencies[x] > frequencies[x-1]){ + // Swap frequencies[x], frequencies[x-1] + int tmp = frequencies[x]; + frequencies[x] = frequencies[x-1]; + frequencies[x-1] = tmp; + + // Swap symbols[x], symbols[x-1] + tmp = symbols[x]; + symbols[x] = symbols[x-1]; + symbols[x-1] = tmp; + } + return symbol; + } + + public void modelRenormalize(){ + // frequencies are halved + totalFrequency = 0; + for (int i=0; i <= maxSymbol; i++){ + frequencies[i] -= Math.floorDiv(frequencies[i],2); + totalFrequency += frequencies[i]; + } + } + + public void modelEncode(final ByteBuffer outBuffer, final RangeCoder rangeCoder, final int symbol){ + + // encodes one input symbol + int cumulativeFrequency = 0; + int i; + for( i = 0; symbols[i] != symbol; i++){ + cumulativeFrequency += frequencies[i]; + } + + // Encode + rangeCoder.rangeEncode(outBuffer, cumulativeFrequency, frequencies[i],totalFrequency); + + // Update Model + frequencies[i] += Constants.STEP; + totalFrequency += Constants.STEP; + if (totalFrequency > Constants.MAX_FREQ){ + modelRenormalize(); + } + + // Keep symbols approximately frequency sorted (ascending order) + if (i > 0 && frequencies[i] > frequencies[i-1]){ + // swap frequencies + int tmp = frequencies[i]; + frequencies[i] = frequencies[i-1]; + frequencies[i-1]=tmp; + + // swap symbols + tmp = symbols[i]; + symbols[i] = symbols[i-1]; + symbols[i-1] = tmp; + } + } + +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/range/Constants.java b/src/main/java/htsjdk/samtools/cram/compression/range/Constants.java new file mode 100644 index 0000000000..e2e941a549 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/range/Constants.java @@ -0,0 +1,8 @@ +package htsjdk.samtools.cram.compression.range; + +final public class Constants { + public static final int NUMBER_OF_SYMBOLS = 256; + public static 
final int MAX_FREQ = ((1<<16)-17); + public static final int STEP = 16; + public static final long MAX_RANGE = 0xFFFFFFFFL; +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/range/RangeCoder.java b/src/main/java/htsjdk/samtools/cram/compression/range/RangeCoder.java new file mode 100644 index 0000000000..a7d7b21828 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/range/RangeCoder.java @@ -0,0 +1,105 @@ +package htsjdk.samtools.cram.compression.range; + +import java.nio.ByteBuffer; + +public class RangeCoder { + + private long low; + private long range; + private long code; + private int FFnum; + private boolean carry; + private int cache; + + protected RangeCoder() { + // Spec: RangeEncodeStart + this.low = 0; + this.range = Constants.MAX_RANGE; // 4 bytes of all 1's + this.code = 0; + this.FFnum = 0; + this.carry = false; + this.cache = 0; + } + + protected void rangeDecodeStart(final ByteBuffer inBuffer){ + for (int i = 0; i < 5; i++){ + code = (code << 8) + (inBuffer.get() & 0xFF); + } + code &= Constants.MAX_RANGE; + } + + protected void rangeDecode(final ByteBuffer inBuffer, final int cumulativeFrequency, final int symbolFrequency){ + code -= cumulativeFrequency * range; + range *= symbolFrequency; + + while (range < (1<<24)) { + range <<= 8; + code = (code << 8) + (inBuffer.get() & 0xFF); // Ensure code is positive + } + } + + protected int rangeGetFrequency(final int totalFrequency){ + range = (long) Math.floor(range / totalFrequency); + return (int) Math.floor(code / range); + } + + protected void rangeEncode( + final ByteBuffer outBuffer, + final int cumulativeFrequency, + final int symbolFrequency, + final int totalFrequency){ + final long old_low = low; + range = (long) Math.floor(range/totalFrequency); + low += cumulativeFrequency * range; + low &= 0xFFFFFFFFL; // keep bottom 4 bytes, shift the top byte out of low + range *= symbolFrequency; + + if (low < old_low) { + carry = true; + } + + // 
Renormalise if range gets too small + while (range < (1<<24)) { + range <<= 8; + rangeShiftLow(outBuffer); + } + + } + + protected void rangeEncodeEnd(final ByteBuffer outBuffer){ + for(int i = 0; i < 5; i++){ + rangeShiftLow(outBuffer); + } + } + + private void rangeShiftLow(final ByteBuffer outBuffer) { + // rangeShiftLow tracks the total number of extra bytes to emit and + // carry indicates whether they are a string of 0xFF or 0x00 values + + // range must be less than (2^24) or (1<<24) or (0x1000000) + // "cache" holds the top byte that will be flushed to the output + + if ((low < 0xff000000L) || carry) { //TODO: 0xff000000L make this magic number a constant + if (carry == false) { + outBuffer.put((byte) cache); + while (FFnum > 0) { + outBuffer.put((byte) 0xFF); + FFnum--; + } + } else { + outBuffer.put((byte) (cache + 1)); + while (FFnum > 0) { + outBuffer.put((byte) 0x00); + FFnum--; + } + + } + cache = (int) (low >>> 24); // Copy of top byte ready for next flush + carry = false; + } else { + FFnum++; + } + low = low<<8 & (0xFFFFFFFFL); // force low to be +ve + } + +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/range/RangeDecode.java b/src/main/java/htsjdk/samtools/cram/compression/range/RangeDecode.java new file mode 100644 index 0000000000..5987630170 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/range/RangeDecode.java @@ -0,0 +1,259 @@ +package htsjdk.samtools.cram.compression.range; + +import htsjdk.samtools.cram.CRAMException; +import htsjdk.samtools.cram.compression.BZIP2ExternalCompressor; +import htsjdk.samtools.cram.compression.CompressionUtils; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +public class RangeDecode { + + private static final ByteBuffer EMPTY_BUFFER = CompressionUtils.allocateByteBuffer(0); + + // This method assumes that inBuffer is already rewound. 
+ // It uncompresses the data in the inBuffer, leaving it consumed. + // Returns a rewound ByteBuffer containing the uncompressed data. + public ByteBuffer uncompress(final ByteBuffer inBuffer) { + + // For Range decoding, the bytes are read in little endian from the input stream + inBuffer.order(ByteOrder.LITTLE_ENDIAN); + return uncompress(inBuffer, 0); + } + + private ByteBuffer uncompress(final ByteBuffer inBuffer, final int outSize) { + if (inBuffer.remaining() == 0) { + return EMPTY_BUFFER; + } + + // the first byte of compressed stream gives the formatFlags + final int formatFlags = inBuffer.get() & 0xFF; + final RangeParams rangeParams = new RangeParams(formatFlags); + + // noSz + int uncompressedSize = rangeParams.isNosz() ? outSize : CompressionUtils.readUint7(inBuffer); + + // stripe + if (rangeParams.isStripe()) { + return decodeStripe(inBuffer, uncompressedSize); + } + + // pack + // if pack, get pack metadata, which will be used later to decode packed data + int packDataLength = 0; + int numSymbols = 0; + byte[] packMappingTable = null; + if (rangeParams.isPack()){ + packDataLength = uncompressedSize; + numSymbols = inBuffer.get() & 0xFF; + + // if (numSymbols > 16 or numSymbols==0), raise exception + if (numSymbols <= 16 && numSymbols!=0) { + packMappingTable = new byte[numSymbols]; + for (int i = 0; i < numSymbols; i++) { + packMappingTable[i] = inBuffer.get(); + } + uncompressedSize = CompressionUtils.readUint7(inBuffer); + } else { + throw new CRAMException("Bit Packing is not permitted when number of distinct symbols is greater than 16 or equal to 0. 
" + + "Number of distinct symbols: " + numSymbols); + } + } + + ByteBuffer outBuffer; + if (rangeParams.isCAT()){ + outBuffer = CompressionUtils.slice(inBuffer); + outBuffer.limit(uncompressedSize); + // While resetting the position to the end is not strictly necessary, + // it is being done for the sake of completeness and + // to meet the requirements of the tests that verify the boundary conditions. + inBuffer.position(inBuffer.position()+uncompressedSize); + } else if (rangeParams.isExternalCompression()){ + final byte[] extCompressedBytes = new byte[inBuffer.remaining()]; + int extCompressedBytesIdx = 0; + final int start = inBuffer.position(); + final int end = inBuffer.limit(); + for (int i = start; i < end; i++) { + extCompressedBytes[extCompressedBytesIdx] = inBuffer.get(); + extCompressedBytesIdx++; + } + outBuffer = uncompressEXT(extCompressedBytes); + } else if (rangeParams.isRLE()){ + outBuffer = CompressionUtils.allocateByteBuffer(uncompressedSize); + switch (rangeParams.getOrder()) { + case ZERO: + uncompressRLEOrder0(inBuffer, outBuffer, uncompressedSize); + break; + case ONE: + uncompressRLEOrder1(inBuffer, outBuffer, uncompressedSize); + break; + } + } else { + outBuffer = CompressionUtils.allocateByteBuffer(uncompressedSize); + switch (rangeParams.getOrder()){ + case ZERO: + uncompressOrder0(inBuffer, outBuffer, uncompressedSize); + break; + case ONE: + uncompressOrder1(inBuffer, outBuffer, uncompressedSize); + break; + } + } + + // if pack, then decodePack + if (rangeParams.isPack()) { + outBuffer = CompressionUtils.decodePack(outBuffer, packMappingTable, numSymbols, packDataLength); + } + outBuffer.rewind(); + return outBuffer; + + } + + private void uncompressOrder0( + final ByteBuffer inBuffer, + final ByteBuffer outBuffer, + final int outSize) { + + int maxSymbols = inBuffer.get() & 0xFF; + maxSymbols = maxSymbols==0 ? 
256 : maxSymbols; + + final ByteModel byteModel = new ByteModel(maxSymbols); + final RangeCoder rangeCoder = new RangeCoder(); + rangeCoder.rangeDecodeStart(inBuffer); + + for (int i = 0; i < outSize; i++) { + outBuffer.put(i, (byte) byteModel.modelDecode(inBuffer, rangeCoder)); + } + } + + private void uncompressOrder1( + final ByteBuffer inBuffer, + final ByteBuffer outBuffer, + final int outSize) { + + int maxSymbols = inBuffer.get() & 0xFF; + maxSymbols = maxSymbols==0 ? 256 : maxSymbols; + final List byteModelList = new ArrayList(maxSymbols); + for(int i=0;i byteModelRunsList = new ArrayList(258); + for (int i=0; i <=257; i++){ + byteModelRunsList.add(i, new ByteModel(4)); + } + RangeCoder rangeCoder = new RangeCoder(); + rangeCoder.rangeDecodeStart(inBuffer); + + int i = 0; + while (i < outSize) { + outBuffer.put(i,(byte) modelLit.modelDecode(inBuffer, rangeCoder)); + final int last = outBuffer.get(i) & (0xFF); + int part = byteModelRunsList.get(last).modelDecode(inBuffer,rangeCoder); + int run = part; + int rctx = 256; + while (part == 3) { + part = byteModelRunsList.get(rctx).modelDecode(inBuffer, rangeCoder); + rctx = 257; + run += part; + } + for (int j = 1; j <= run; j++){ + outBuffer.put(i+j, (byte) last); + } + i += run+1; + } + } + + private void uncompressRLEOrder1( + final ByteBuffer inBuffer, + final ByteBuffer outBuffer, + final int outSize) { + + int maxSymbols = inBuffer.get() & 0xFF; + maxSymbols = maxSymbols == 0 ? 
256 : maxSymbols; + final List byteModelLitList = new ArrayList(maxSymbols); + for (int i=0; i < maxSymbols; i++) { + byteModelLitList.add(i,new ByteModel(maxSymbols)); + } + final List byteModelRunsList = new ArrayList(258); + for (int i=0; i <=257; i++){ + byteModelRunsList.add(i, new ByteModel(4)); + } + + final RangeCoder rangeCoder = new RangeCoder(); + rangeCoder.rangeDecodeStart(inBuffer); + + int last = 0; + int i = 0; + while (i < outSize) { + outBuffer.put(i,(byte) byteModelLitList.get(last).modelDecode(inBuffer, rangeCoder)); + last = outBuffer.get(i) & 0xFF; + int part = byteModelRunsList.get(last).modelDecode(inBuffer,rangeCoder); + int run = part; + int rctx = 256; + while (part == 3) { + part = byteModelRunsList.get(rctx).modelDecode(inBuffer, rangeCoder); + rctx = 257; + run += part; + } + for (int j = 1; j <= run; j++){ + outBuffer.put(i+j, (byte)last); + } + i += run+1; + } + } + + private ByteBuffer uncompressEXT(final byte[] extCompressedBytes) { + final BZIP2ExternalCompressor compressor = new BZIP2ExternalCompressor(); + final byte [] extUncompressedBytes = compressor.uncompress(extCompressedBytes); + return CompressionUtils.wrap(extUncompressedBytes); + } + + private ByteBuffer decodeStripe(final ByteBuffer inBuffer, final int outSize){ + final int numInterleaveStreams = inBuffer.get() & 0xFF; + + // read lengths of compressed interleaved streams + for ( int j=0; j j){ + uncompressedLengths[j]++; + } + + transposedData[j] = uncompress(inBuffer, uncompressedLengths[j]); + } + + // Transpose + final ByteBuffer outBuffer = CompressionUtils.allocateByteBuffer(outSize); + for (int j = 0; j { + + private static final ByteBuffer EMPTY_BUFFER = CompressionUtils.allocateByteBuffer(0); + + // This method assumes that inBuffer is already rewound. + // It compresses the data in the inBuffer, leaving it consumed. + // Returns a rewound ByteBuffer containing the compressed data. 
+ public ByteBuffer compress(final ByteBuffer inBuffer, final RangeParams rangeParams) { + if (inBuffer.remaining() == 0) { + return EMPTY_BUFFER; + } + + final ByteBuffer outBuffer = CompressionUtils.allocateOutputBuffer(inBuffer.remaining()); + outBuffer.order(ByteOrder.BIG_ENDIAN); + final int formatFlags = rangeParams.getFormatFlags(); + outBuffer.put((byte) (formatFlags)); + + if (!rangeParams.isNosz()) { + // original size is not recorded + CompressionUtils.writeUint7(inBuffer.remaining(), outBuffer); + } + + ByteBuffer inputBuffer = inBuffer; + + // Stripe flag is not implemented in the write implementation + if (rangeParams.isStripe()) { + throw new CRAMException("Range Encoding with Stripe Flag is not implemented."); + } + + final int inSize = inputBuffer.remaining(); // e_len -> inSize + + // Pack + if (rangeParams.isPack()) { + final int[] frequencyTable = new int[Constants.NUMBER_OF_SYMBOLS]; + for (int i = 0; i < inSize; i++) { + frequencyTable[inputBuffer.get(i) & 0xFF]++; + } + int numSymbols = 0; + final int[] packMappingTable = new int[Constants.NUMBER_OF_SYMBOLS]; + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + if (frequencyTable[i] > 0) { + packMappingTable[i] = numSymbols++; + } + } + + // skip Packing if numSymbols = 0 or numSymbols > 16 + if (numSymbols != 0 && numSymbols <= 16) { + inputBuffer = CompressionUtils.encodePack(inputBuffer, outBuffer, frequencyTable, packMappingTable, numSymbols); + } else { + // unset pack flag in the first byte of the outBuffer + outBuffer.put(0, (byte) (outBuffer.get(0) & ~RangeParams.PACK_FLAG_MASK)); + } + } + + if (rangeParams.isCAT()) { + + // Data is uncompressed + outBuffer.put(inputBuffer); + outBuffer.limit(outBuffer.position()); + outBuffer.rewind(); // set position to 0 + } else if (rangeParams.isExternalCompression()) { + final byte[] rawBytes = new byte[inputBuffer.remaining()]; + inputBuffer.get(rawBytes, inBuffer.position(), inputBuffer.remaining()); + final BZIP2ExternalCompressor 
compressor = new BZIP2ExternalCompressor(); + final byte[] extCompressedBytes = compressor.compress(rawBytes); + outBuffer.put(extCompressedBytes); + outBuffer.limit(outBuffer.position()); + outBuffer.rewind(); // set position to 0 + } else if (rangeParams.isRLE()) { + switch (rangeParams.getOrder()) { + case ZERO: + compressRLEOrder0(inputBuffer, outBuffer); + break; + case ONE: + compressRLEOrder1(inputBuffer, outBuffer); + break; + default: + throw new CRAMException("Unknown range order: " + rangeParams.getOrder()); + } + } else { + switch (rangeParams.getOrder()) { + case ZERO: + compressOrder0(inputBuffer, outBuffer); + break; + case ONE: + compressOrder1(inputBuffer, outBuffer); + break; + default: + throw new CRAMException("Unknown range order: " + rangeParams.getOrder()); + } + } + return outBuffer; + } + + private void compressOrder0( + final ByteBuffer inBuffer, + final ByteBuffer outBuffer) { + + int maxSymbol = 0; + final int inSize = inBuffer.remaining(); + for (int i = 0; i < inSize; i++) { + if (maxSymbol < (inBuffer.get(i) & 0xFF)) { + maxSymbol = inBuffer.get(i) & 0xFF; + } + } + maxSymbol++; + final ByteModel byteModel = new ByteModel(maxSymbol); + outBuffer.put((byte) maxSymbol); + final RangeCoder rangeCoder = new RangeCoder(); + for (int i = 0; i < inSize; i++) { + byteModel.modelEncode(outBuffer, rangeCoder, inBuffer.get(i) & 0xFF); + } + rangeCoder.rangeEncodeEnd(outBuffer); + outBuffer.limit(outBuffer.position()); + outBuffer.rewind(); + } + + private void compressOrder1( + final ByteBuffer inBuffer, + final ByteBuffer outBuffer) { + int maxSymbol = 0; + final int inSize = inBuffer.remaining(); + for (int i = 0; i < inSize; i++) { + if (maxSymbol < (inBuffer.get(i) & 0xFF)) { + maxSymbol = inBuffer.get(i) & 0xFF; + } + } + maxSymbol++; + final List byteModelList = new ArrayList(); + for (int i = 0; i < maxSymbol; i++) { + byteModelList.add(i, new ByteModel(maxSymbol)); + } + outBuffer.put((byte) maxSymbol); + final RangeCoder rangeCoder = 
new RangeCoder(); + int last = 0; + for (int i = 0; i < inSize; i++) { + byteModelList.get(last).modelEncode(outBuffer, rangeCoder, inBuffer.get(i) & 0xFF); + last = inBuffer.get(i) & 0xFF; + } + rangeCoder.rangeEncodeEnd(outBuffer); + outBuffer.limit(outBuffer.position()); + outBuffer.rewind(); + } + + private void compressRLEOrder0( + final ByteBuffer inBuffer, + final ByteBuffer outBuffer) { + int maxSymbols = 0; + final int inSize = inBuffer.remaining(); + for (int i = 0; i < inSize; i++) { + if (maxSymbols < (inBuffer.get(i) & 0xFF)) { + maxSymbols = inBuffer.get(i) & 0xFF; + } + } + maxSymbols++; // FIXME not what spec states! + + final ByteModel modelLit = new ByteModel(maxSymbols); + final List byteModelRunsList = new ArrayList(258); + + for (int i = 0; i <= 257; i++) { + byteModelRunsList.add(i, new ByteModel(4)); + } + outBuffer.put((byte) maxSymbols); + final RangeCoder rangeCoder = new RangeCoder(); + int i = 0; + while (i < inSize) { + modelLit.modelEncode(outBuffer, rangeCoder, inBuffer.get(i) & 0xFF); + int run = 1; + while (i + run < inSize && (inBuffer.get(i + run) & 0xFF) == (inBuffer.get(i) & 0xFF)) { + run++; + } + run--; // Check this!! + int rctx = inBuffer.get(i) & 0xFF; + i += run + 1; + int part = run >= 3 ? 3 : run; + byteModelRunsList.get(rctx).modelEncode(outBuffer, rangeCoder, part); + run -= part; + rctx = 256; + while (part == 3) { + part = run >= 3 ? 3 : run; + byteModelRunsList.get(rctx).modelEncode(outBuffer, rangeCoder, part); + rctx = 257; + run -= part; + } + } + rangeCoder.rangeEncodeEnd(outBuffer); + outBuffer.limit(outBuffer.position()); + outBuffer.rewind(); + } + + private void compressRLEOrder1( + final ByteBuffer inBuffer, + final ByteBuffer outBuffer) { + int maxSymbols = 0; + final int inSize = inBuffer.remaining(); + for (int i = 0; i < inSize; i++) { + if (maxSymbols < (inBuffer.get(i) & 0xFF)) { + maxSymbols = inBuffer.get(i) & 0xFF; + } + } + maxSymbols++; // FIXME not what spec states! 
/**
 * Immutable view of the format-flags byte that starts every range-coded stream. The byte
 * consists of the bit-flags detailing the type of transformations and entropy encoders to be
 * combined.
 */
public class RangeParams {
    public static final int ORDER_FLAG_MASK = 0x01;
    public static final int EXT_FLAG_MASK = 0x04;
    public static final int STRIPE_FLAG_MASK = 0x08;
    public static final int NOSZ_FLAG_MASK = 0x10;
    public static final int CAT_FLAG_MASK = 0x20;
    public static final int RLE_FLAG_MASK = 0x40;
    public static final int PACK_FLAG_MASK = 0x80;

    private static final int FORMAT_FLAG_MASK = 0xFF;

    // Stored exactly as passed in; getFormatFlags() exposes only the low byte.
    private final int formatFlags;

    /** Order of the entropy coder, selected by bit {@link #ORDER_FLAG_MASK}. */
    public enum ORDER {
        ZERO, ONE;

        /**
         * Maps 0 to {@link #ZERO} and 1 to {@link #ONE}.
         *
         * @throws IllegalArgumentException if {@code orderValue} is neither 0 nor 1
         */
        public static RangeParams.ORDER fromInt(final int orderValue) {
            final ORDER[] orders = ORDER.values();
            if (orderValue < 0 || orderValue >= orders.length) {
                throw new IllegalArgumentException("Unknown Range order: " + orderValue);
            }
            return orders[orderValue];
        }
    }

    /**
     * @param formatFlags the format flags; only the low 8 bits are reported by {@link #getFormatFlags()}
     */
    public RangeParams(final int formatFlags) {
        this.formatFlags = formatFlags;
    }

    @Override
    public String toString() {
        return "RangeParams{" + "formatFlags=" + formatFlags + "}";
    }

    /** Returns the low byte of the flags, i.e. the first byte of the encoded stream. */
    public int getFormatFlags() {
        return formatFlags & FORMAT_FLAG_MASK;
    }

    /** Returns whether the stream uses order ZERO or order ONE range coding. */
    public ORDER getOrder() {
        return ORDER.fromInt(formatFlags & ORDER_FLAG_MASK);
    }

    /** "External" compression via bzip2. */
    public boolean isExternalCompression() {
        return (formatFlags & EXT_FLAG_MASK) != 0;
    }

    /** Multiway interleaving of byte streams. */
    public boolean isStripe() {
        return (formatFlags & STRIPE_FLAG_MASK) != 0;
    }

    /** Original size is not recorded (for use by Stripe). */
    public boolean isNosz() {
        return (formatFlags & NOSZ_FLAG_MASK) != 0;
    }

    /** Data is stored uncompressed. */
    public boolean isCAT() {
        return (formatFlags & CAT_FLAG_MASK) != 0;
    }

    /** Run length encoding, with runs and literals encoded separately. */
    public boolean isRLE() {
        return (formatFlags & RLE_FLAG_MASK) != 0;
    }

    /** Bit packing: several symbols are stored per byte. */
    public boolean isPack() {
        return (formatFlags & PACK_FLAG_MASK) != 0;
    }

}
a/src/main/java/htsjdk/samtools/cram/compression/rans/ArithmeticDecoder.java +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/ArithmeticDecoder.java @@ -24,25 +24,25 @@ */ package htsjdk.samtools.cram.compression.rans; -final class ArithmeticDecoder { - final FC[] fc = new FC[256]; +final public class ArithmeticDecoder { + public final int[] frequencies = new int[Constants.NUMBER_OF_SYMBOLS]; - // reverse lookup table ? - byte[] R = new byte[Constants.TOTFREQ]; + // reverse lookup table + public final byte[] reverseLookup = new byte[Constants.TOTAL_FREQ]; public ArithmeticDecoder() { - for (int i = 0; i < 256; i++) { - fc[i] = new FC(); + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + frequencies[i] = 0; } } public void reset() { - for (int i = 0; i < 256; i++) { - fc[i].reset(); + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + frequencies[i] = 0; } - for (int i = 0; i < Constants.TOTFREQ; i++) { - R[i] = 0; + for (int i = 0; i < Constants.TOTAL_FREQ; i++) { + reverseLookup[i] = 0; } } -} +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/Constants.java b/src/main/java/htsjdk/samtools/cram/compression/rans/Constants.java index 7c7545bfbe..f970582f48 100644 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/Constants.java +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/Constants.java @@ -1,7 +1,13 @@ package htsjdk.samtools.cram.compression.rans; -final class Constants { - static final int TF_SHIFT = 12; - static final int TOTFREQ = (1 << TF_SHIFT); // 4096 - static final int RANS_BYTE_L = 1 << 23; -} +final public class Constants { + public static final int TOTAL_FREQ_SHIFT = 12; + public static final int TOTAL_FREQ = (1 << TOTAL_FREQ_SHIFT); // 4096 + public static final int NUMBER_OF_SYMBOLS = 256; + public static final int RANS_4x8_LOWER_BOUND = 1 << 23; + public static final int RANS_4x8_ORDER_BYTE_LENGTH = 1; + public static final int RANS_4x8_COMPRESSED_BYTE_LENGTH = 
4; + public static final int RANS_4x8_RAW_BYTE_LENGTH = 4; + public static final int RANS_4x8_PREFIX_BYTE_LENGTH = RANS_4x8_ORDER_BYTE_LENGTH + RANS_4x8_COMPRESSED_BYTE_LENGTH + RANS_4x8_RAW_BYTE_LENGTH; + public static final int RANS_Nx16_LOWER_BOUND = 1 << 15; +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/D04.java b/src/main/java/htsjdk/samtools/cram/compression/rans/D04.java deleted file mode 100644 index e9d9941575..0000000000 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/D04.java +++ /dev/null @@ -1,80 +0,0 @@ -package htsjdk.samtools.cram.compression.rans; - -import java.nio.ByteBuffer; - -final class D04 { - static void uncompress( - final ByteBuffer inBuffer, - final ArithmeticDecoder D, - final RANSDecodingSymbol[] syms, - final ByteBuffer outBuffer) { - int rans0, rans1, rans2, rans3; - rans0 = inBuffer.getInt(); - rans1 = inBuffer.getInt(); - rans2 = inBuffer.getInt(); - rans3 = inBuffer.getInt(); - - final int out_sz = outBuffer.remaining(); - final int out_end = (out_sz & ~3); - for (int i = 0; i < out_end; i += 4) { - final byte c0 = D.R[Utils.RANSDecodeGet(rans0, Constants.TF_SHIFT)]; - final byte c1 = D.R[Utils.RANSDecodeGet(rans1, Constants.TF_SHIFT)]; - final byte c2 = D.R[Utils.RANSDecodeGet(rans2, Constants.TF_SHIFT)]; - final byte c3 = D.R[Utils.RANSDecodeGet(rans3, Constants.TF_SHIFT)]; - - outBuffer.put(i, c0); - outBuffer.put(i + 1, c1); - outBuffer.put(i + 2, c2); - outBuffer.put(i + 3, c3); - - rans0 = syms[0xFF & c0].advanceSymbolStep(rans0, Constants.TF_SHIFT); - rans1 = syms[0xFF & c1].advanceSymbolStep(rans1, Constants.TF_SHIFT); - rans2 = syms[0xFF & c2].advanceSymbolStep(rans2, Constants.TF_SHIFT); - rans3 = syms[0xFF & c3].advanceSymbolStep(rans3, Constants.TF_SHIFT); - - rans0 = Utils.RANSDecodeRenormalize(rans0, inBuffer); - rans1 = Utils.RANSDecodeRenormalize(rans1, inBuffer); - rans2 = Utils.RANSDecodeRenormalize(rans2, inBuffer); - rans3 = 
Utils.RANSDecodeRenormalize(rans3, inBuffer); - } - - outBuffer.position(out_end); - byte c; - switch (out_sz & 3) { - case 0: - break; - - case 1: - c = D.R[Utils.RANSDecodeGet(rans0, Constants.TF_SHIFT)]; - syms[0xFF & c].advanceSymbol(rans0, inBuffer, Constants.TF_SHIFT); - outBuffer.put(c); - break; - - case 2: - c = D.R[Utils.RANSDecodeGet(rans0, Constants.TF_SHIFT)]; - syms[0xFF & c].advanceSymbol(rans0, inBuffer, Constants.TF_SHIFT); - outBuffer.put(c); - - c = D.R[Utils.RANSDecodeGet(rans1, Constants.TF_SHIFT)]; - syms[0xFF & c].advanceSymbol(rans1, inBuffer, Constants.TF_SHIFT); - outBuffer.put(c); - break; - - case 3: - c = D.R[Utils.RANSDecodeGet(rans0, Constants.TF_SHIFT)]; - syms[0xFF & c].advanceSymbol(rans0, inBuffer, Constants.TF_SHIFT); - outBuffer.put(c); - - c = D.R[Utils.RANSDecodeGet(rans1, Constants.TF_SHIFT)]; - syms[0xFF & c].advanceSymbol(rans1, inBuffer, Constants.TF_SHIFT); - outBuffer.put(c); - - c = D.R[Utils.RANSDecodeGet(rans2, Constants.TF_SHIFT)]; - syms[0xFF & c].advanceSymbol(rans2, inBuffer, Constants.TF_SHIFT); - outBuffer.put(c); - break; - } - - outBuffer.position(0); - } -} diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/D14.java b/src/main/java/htsjdk/samtools/cram/compression/rans/D14.java deleted file mode 100644 index ba7d598d9e..0000000000 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/D14.java +++ /dev/null @@ -1,64 +0,0 @@ -package htsjdk.samtools.cram.compression.rans; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -final class D14 { - static void uncompress( - final ByteBuffer inBuffer, - final ByteBuffer outBuffer, - final ArithmeticDecoder[] D, - final RANSDecodingSymbol[][] syms) { - final int out_sz = outBuffer.remaining(); - int rans0, rans1, rans2, rans7; - inBuffer.order(ByteOrder.LITTLE_ENDIAN); - rans0 = inBuffer.getInt(); - rans1 = inBuffer.getInt(); - rans2 = inBuffer.getInt(); - rans7 = inBuffer.getInt(); - - final int isz4 = out_sz >> 2; - int i0 = 0; - int i1 = 
isz4; - int i2 = 2 * isz4; - int i7 = 3 * isz4; - int l0 = 0; - int l1 = 0; - int l2 = 0; - int l7 = 0; - for (; i0 < isz4; i0++, i1++, i2++, i7++) { - final int c0 = 0xFF & D[l0].R[Utils.RANSDecodeGet(rans0, Constants.TF_SHIFT)]; - final int c1 = 0xFF & D[l1].R[Utils.RANSDecodeGet(rans1, Constants.TF_SHIFT)]; - final int c2 = 0xFF & D[l2].R[Utils.RANSDecodeGet(rans2, Constants.TF_SHIFT)]; - final int c7 = 0xFF & D[l7].R[Utils.RANSDecodeGet(rans7, Constants.TF_SHIFT)]; - - outBuffer.put(i0, (byte) c0); - outBuffer.put(i1, (byte) c1); - outBuffer.put(i2, (byte) c2); - outBuffer.put(i7, (byte) c7); - - rans0 = syms[l0][c0].advanceSymbolStep(rans0, Constants.TF_SHIFT); - rans1 = syms[l1][c1].advanceSymbolStep(rans1, Constants.TF_SHIFT); - rans2 = syms[l2][c2].advanceSymbolStep(rans2, Constants.TF_SHIFT); - rans7 = syms[l7][c7].advanceSymbolStep(rans7, Constants.TF_SHIFT); - - rans0 = Utils.RANSDecodeRenormalize(rans0, inBuffer); - rans1 = Utils.RANSDecodeRenormalize(rans1, inBuffer); - rans2 = Utils.RANSDecodeRenormalize(rans2, inBuffer); - rans7 = Utils.RANSDecodeRenormalize(rans7, inBuffer); - - l0 = c0; - l1 = c1; - l2 = c2; - l7 = c7; - } - - // Remainder - for (; i7 < out_sz; i7++) { - final int c7 = 0xFF & D[l7].R[Utils.RANSDecodeGet(rans7, Constants.TF_SHIFT)]; - outBuffer.put(i7, (byte) c7); - rans7 = syms[l7][c7].advanceSymbol(rans7, inBuffer, Constants.TF_SHIFT); - l7 = c7; - } - } -} diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/E04.java b/src/main/java/htsjdk/samtools/cram/compression/rans/E04.java deleted file mode 100644 index 709c7096b0..0000000000 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/E04.java +++ /dev/null @@ -1,52 +0,0 @@ -package htsjdk.samtools.cram.compression.rans; - -import java.nio.ByteBuffer; - -final class E04 { - - static int compress(final ByteBuffer inBuffer, final RANSEncodingSymbol[] syms, final ByteBuffer cp) { - final int cdata_size; - final int in_size = inBuffer.remaining(); - int rans0, rans1, 
rans2, rans3; - final ByteBuffer ptr = cp.slice(); - - rans0 = Constants.RANS_BYTE_L; - rans1 = Constants.RANS_BYTE_L; - rans2 = Constants.RANS_BYTE_L; - rans3 = Constants.RANS_BYTE_L; - - int i; - switch (i = (in_size & 3)) { - case 3: - rans2 = syms[0xFF & inBuffer.get(in_size - (i - 2))].putSymbol(rans2, ptr); - case 2: - rans1 = syms[0xFF & inBuffer.get(in_size - (i - 1))].putSymbol(rans1, ptr); - case 1: - rans0 = syms[0xFF & inBuffer.get(in_size - (i))].putSymbol(rans0, ptr); - case 0: - break; - } - for (i = (in_size & ~3); i > 0; i -= 4) { - final int c3 = 0xFF & inBuffer.get(i - 1); - final int c2 = 0xFF & inBuffer.get(i - 2); - final int c1 = 0xFF & inBuffer.get(i - 3); - final int c0 = 0xFF & inBuffer.get(i - 4); - - rans3 = syms[c3].putSymbol(rans3, ptr); - rans2 = syms[c2].putSymbol(rans2, ptr); - rans1 = syms[c1].putSymbol(rans1, ptr); - rans0 = syms[c0].putSymbol(rans0, ptr); - } - - ptr.putInt(rans3); - ptr.putInt(rans2); - ptr.putInt(rans1); - ptr.putInt(rans0); - ptr.flip(); - cdata_size = ptr.limit(); - // reverse the compressed bytes, so that they become in REVERSE order: - Utils.reverse(ptr); - inBuffer.position(inBuffer.limit()); - return cdata_size; - } -} diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/E14.java b/src/main/java/htsjdk/samtools/cram/compression/rans/E14.java deleted file mode 100644 index 37f2767137..0000000000 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/E14.java +++ /dev/null @@ -1,87 +0,0 @@ -package htsjdk.samtools.cram.compression.rans; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -final class E14 { - - static int compress(final ByteBuffer inBuffer, final RANSEncodingSymbol[][] syms, final ByteBuffer outBuffer) { - final int in_size = inBuffer.remaining(); - final int compressedBlobSize; - int rans0, rans1, rans2, rans3; - rans0 = Constants.RANS_BYTE_L; - rans1 = Constants.RANS_BYTE_L; - rans2 = Constants.RANS_BYTE_L; - rans3 = Constants.RANS_BYTE_L; - - /* - * Slicing is 
needed for buffer reversing later. - */ - final ByteBuffer ptr = outBuffer.slice(); - - final int isz4 = in_size >> 2; - int i0 = isz4 - 2; - int i1 = 2 * isz4 - 2; - int i2 = 3 * isz4 - 2; - int i3 = 4 * isz4 - 2; - - int l0 = 0; - if (i0 + 1 >= 0) { - l0 = 0xFF & inBuffer.get(i0 + 1); - } - int l1 = 0; - if (i1 + 1 >= 0) { - l1 = 0xFF & inBuffer.get(i1 + 1); - } - int l2 = 0; - if (i2 + 1 >= 0) { - l2 = 0xFF & inBuffer.get(i2 + 1); - } - int l3; - - // Deal with the remainder - l3 = 0xFF & inBuffer.get(in_size - 1); - for (i3 = in_size - 2; i3 > 4 * isz4 - 2 && i3 >= 0; i3--) { - final int c3 = 0xFF & inBuffer.get(i3); - rans3 = syms[c3][l3].putSymbol(rans3, ptr); - l3 = c3; - } - - for (; i0 >= 0; i0--, i1--, i2--, i3--) { - final int c0 = 0xFF & inBuffer.get(i0); - final int c1 = 0xFF & inBuffer.get(i1); - final int c2 = 0xFF & inBuffer.get(i2); - final int c3 = 0xFF & inBuffer.get(i3); - - rans3 = syms[c3][l3].putSymbol(rans3, ptr); - rans2 = syms[c2][l2].putSymbol(rans2, ptr); - rans1 = syms[c1][l1].putSymbol(rans1, ptr); - rans0 = syms[c0][l0].putSymbol(rans0, ptr); - - l0 = c0; - l1 = c1; - l2 = c2; - l3 = c3; - } - - rans3 = syms[0][l3].putSymbol(rans3, ptr); - rans2 = syms[0][l2].putSymbol(rans2, ptr); - rans1 = syms[0][l1].putSymbol(rans1, ptr); - rans0 = syms[0][l0].putSymbol(rans0, ptr); - - ptr.order(ByteOrder.BIG_ENDIAN); - ptr.putInt(rans3); - ptr.putInt(rans2); - ptr.putInt(rans1); - ptr.putInt(rans0); - ptr.flip(); - compressedBlobSize = ptr.limit(); - Utils.reverse(ptr); - /* - * Depletion of the in buffer cannot be confirmed because of the get(int - * position) method use during encoding, hence enforcing: - */ - inBuffer.position(inBuffer.limit()); - return compressedBlobSize; - } -} diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/FC.java b/src/main/java/htsjdk/samtools/cram/compression/rans/FC.java deleted file mode 100644 index dc08e5f132..0000000000 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/FC.java +++ 
/dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2019 The Broad Institute - * - * Permission is hereby granted, free of charge, to any person - * obtaining a copy of this software and associated documentation - * files (the "Software"), to deal in the Software without - * restriction, including without limitation the rights to use, - * copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following - * conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT - * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, - * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR - * THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- */ -package htsjdk.samtools.cram.compression.rans; - -final class FC { - int F, C; - - public void reset() { - F = C = 0; - } - -} diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/Frequencies.java b/src/main/java/htsjdk/samtools/cram/compression/rans/Frequencies.java deleted file mode 100644 index c174ad1396..0000000000 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/Frequencies.java +++ /dev/null @@ -1,319 +0,0 @@ -package htsjdk.samtools.cram.compression.rans; - -import java.nio.ByteBuffer; -import java.util.Arrays; - -// T = total of true counts -// F = scaled integer frequencies -// M = sum(fs) - -final class Frequencies { - - static void readStatsOrder0(final ByteBuffer cp, final ArithmeticDecoder decoder, final RANSDecodingSymbol[] decodingSymbols) { - // Pre-compute reverse lookup of frequency. - int rle = 0; - int x = 0; - int j = cp.get() & 0xFF; - do { - if ((decoder.fc[j].F = (cp.get() & 0xFF)) >= 128) { - decoder.fc[j].F &= ~128; - decoder.fc[j].F = ((decoder.fc[j].F & 127) << 8) | (cp.get() & 0xFF); - } - decoder.fc[j].C = x; - - decodingSymbols[j].set(decoder.fc[j].C, decoder.fc[j].F); - - /* Build reverse lookup table */ - Arrays.fill(decoder.R, x, x + decoder.fc[j].F, (byte) j); - - x += decoder.fc[j].F; - - if (rle == 0 && j + 1 == (0xFF & cp.get(cp.position()))) { - j = cp.get() & 0xFF; - rle = cp.get() & 0xFF; - } else if (rle != 0) { - rle--; - j++; - } else { - j = cp.get() & 0xFF; - } - } while (j != 0); - - assert (x < Constants.TOTFREQ); - } - - static void readStatsOrder1(final ByteBuffer cp, final ArithmeticDecoder[] D, final RANSDecodingSymbol[][] decodingSymbols) { - int rle_i = 0; - int i = 0xFF & cp.get(); - do { - int rle_j = 0; - int x = 0; - int j = 0xFF & cp.get(); - do { - if ((D[i].fc[j].F = (0xFF & cp.get())) >= 128) { - D[i].fc[j].F &= ~128; - D[i].fc[j].F = ((D[i].fc[j].F & 127) << 8) | (0xFF & cp.get()); - } - D[i].fc[j].C = x; - - if (D[i].fc[j].F == 0) { - D[i].fc[j].F = Constants.TOTFREQ; - } - 
- decodingSymbols[i][j].set( - D[i].fc[j].C, - D[i].fc[j].F - ); - - /* Build reverse lookup table */ - Arrays.fill(D[i].R, x, x + D[i].fc[j].F, (byte) j); - - x += D[i].fc[j].F; - assert (x <= Constants.TOTFREQ); - - if (rle_j == 0 && j + 1 == (0xFF & cp.get(cp.position()))) { - j = (0xFF & cp.get()); - rle_j = (0xFF & cp.get()); - } else if (rle_j != 0) { - rle_j--; - j++; - } else { - j = (0xFF & cp.get()); - } - } while (j != 0); - - if (rle_i == 0 && i + 1 == (0xFF & cp.get(cp.position()))) { - i = (0xFF & cp.get()); - rle_i = (0xFF & cp.get()); - } else if (rle_i != 0) { - rle_i--; - i++; - } else { - i = (0xFF & cp.get()); - } - } while (i != 0); - } - - static int[] calcFrequenciesOrder0(final ByteBuffer inBuffer) { - final int inSize = inBuffer.remaining(); - - // Compute statistics - final int[] F = new int[RANS.NUMBER_OF_SYMBOLS]; - int T = 0; - for (int i = 0; i < inSize; i++) { - F[0xFF & inBuffer.get()]++; - T++; - } - final long tr = ((long) Constants.TOTFREQ << 31) / T + (1 << 30) / T; - - // Normalise so T[i] == TOTFREQ - int m = 0; - int M = 0; // frequency denominator ? 
- for (int j = 0; j < RANS.NUMBER_OF_SYMBOLS; j++) { - if (m < F[j]) { - m = F[j]; - M = j; - } - } - - int fsum = 0; - for (int j = 0; j < RANS.NUMBER_OF_SYMBOLS; j++) { - if (F[j] == 0) { - continue; - } - if ((F[j] = (int) ((F[j] * tr) >> 31)) == 0) { - F[j] = 1; - } - fsum += F[j]; - } - - fsum++; - if (fsum < Constants.TOTFREQ) { - F[M] += Constants.TOTFREQ - fsum; - } else { - F[M] -= fsum - Constants.TOTFREQ; - } - - assert (F[M] > 0); - return F; - } - - static int[][] calcFrequenciesOrder1(final ByteBuffer in) { - final int in_size = in.remaining(); - - final int[][] F = new int[RANS.NUMBER_OF_SYMBOLS][RANS.NUMBER_OF_SYMBOLS]; - final int[] T = new int[RANS.NUMBER_OF_SYMBOLS]; - int c; - - int last_i = 0; - for (int i = 0; i < in_size; i++) { - F[last_i][c = (0xFF & in.get())]++; - T[last_i]++; - last_i = c; - } - F[0][0xFF & in.get((in_size >> 2))]++; - F[0][0xFF & in.get(2 * (in_size >> 2))]++; - F[0][0xFF & in.get(3 * (in_size >> 2))]++; - T[0] += 3; - - for (int i = 0; i < RANS.NUMBER_OF_SYMBOLS; i++) { - if (T[i] == 0) { - continue; - } - - final double p = ((double) Constants.TOTFREQ) / T[i]; - int t2 = 0, m = 0, M = 0; - for (int j = 0; j < RANS.NUMBER_OF_SYMBOLS; j++) { - if (F[i][j] == 0) - continue; - - if (m < F[i][j]) { - m = F[i][j]; - M = j; - } - - if ((F[i][j] *= p) == 0) - F[i][j] = 1; - t2 += F[i][j]; - } - - t2++; - if (t2 < Constants.TOTFREQ) { - F[i][M] += Constants.TOTFREQ - t2; - } else { - F[i][M] -= t2 - Constants.TOTFREQ; - } - } - - return F; - } - - static RANSEncodingSymbol[] buildSymsOrder0(final int[] F, final RANSEncodingSymbol[] syms) { - final int[] C = new int[RANS.NUMBER_OF_SYMBOLS]; - - int T = 0; - for (int j = 0; j < RANS.NUMBER_OF_SYMBOLS; j++) { - C[j] = T; - T += F[j]; - if (F[j] != 0) { - syms[j].set(C[j], F[j], Constants.TF_SHIFT); - } - } - return syms; - } - - static int writeFrequenciesOrder0(final ByteBuffer cp, final int[] F) { - final int start = cp.position(); - - int rle = 0; - for (int j = 0; j < 
RANS.NUMBER_OF_SYMBOLS; j++) { - if (F[j] != 0) { - // j - if (rle != 0) { - rle--; - } else { - cp.put((byte) j); - if (rle == 0 && j != 0 && F[j - 1] != 0) { - for (rle = j + 1; rle < 256 && F[rle] != 0; rle++) - ; - rle -= j + 1; - cp.put((byte) rle); - } - } - - // F[j] - if (F[j] < 128) { - cp.put((byte) (F[j])); - } else { - cp.put((byte) (128 | (F[j] >> 8))); - cp.put((byte) (F[j] & 0xff)); - } - } - } - - cp.put((byte) 0); - return cp.position() - start; - } - - static RANSEncodingSymbol[][] buildSymsOrder1(final int[][] F, final RANSEncodingSymbol[][] syms) { - for (int i = 0; i < RANS.NUMBER_OF_SYMBOLS; i++) { - final int[] F_i_ = F[i]; - int x = 0; - for (int j = 0; j < RANS.NUMBER_OF_SYMBOLS; j++) { - if (F_i_[j] != 0) { - syms[i][j].set(x, F_i_[j], Constants.TF_SHIFT); - x += F_i_[j]; - } - } - } - - return syms; - } - - static int writeFrequenciesOrder1(final ByteBuffer cp, final int[][] F) { - final int start = cp.position(); - final int[] T = new int[RANS.NUMBER_OF_SYMBOLS]; - - for (int i = 0; i < RANS.NUMBER_OF_SYMBOLS; i++) { - for (int j = 0; j < RANS.NUMBER_OF_SYMBOLS; j++) { - T[i] += F[i][j]; - } - } - - int rle_i = 0; - for (int i = 0; i < RANS.NUMBER_OF_SYMBOLS; i++) { - if (T[i] == 0) { - continue; - } - - // Store frequency table - // i - if (rle_i != 0) { - rle_i--; - } else { - cp.put((byte) i); - // FIXME: could use order-0 statistics to observe which alphabet - // symbols are present and base RLE on that ordering instead. 
- if (i != 0 && T[i - 1] != 0) { - for (rle_i = i + 1; rle_i < 256 && T[rle_i] != 0; rle_i++) - ; - rle_i -= i + 1; - cp.put((byte) rle_i); - } - } - - final int[] F_i_ = F[i]; - int rle_j = 0; - for (int j = 0; j < RANS.NUMBER_OF_SYMBOLS; j++) { - if (F_i_[j] != 0) { - - // j - if (rle_j != 0) { - rle_j--; - } else { - cp.put((byte) j); - if (rle_j == 0 && j != 0 && F_i_[j - 1] != 0) { - for (rle_j = j + 1; rle_j < 256 && F_i_[rle_j] != 0; rle_j++) - ; - rle_j -= j + 1; - cp.put((byte) rle_j); - } - } - - // F_i_[j] - if (F_i_[j] < 128) { - cp.put((byte) F_i_[j]); - } else { - cp.put((byte) (128 | (F_i_[j] >> 8))); - cp.put((byte) (F_i_[j] & 0xff)); - } - } - } - cp.put((byte) 0); - } - cp.put((byte) 0); - - return cp.position() - start; - } - -} diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/RANS.java b/src/main/java/htsjdk/samtools/cram/compression/rans/RANS.java deleted file mode 100644 index 8a4e719ff5..0000000000 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/RANS.java +++ /dev/null @@ -1,233 +0,0 @@ -package htsjdk.samtools.cram.compression.rans; - -import htsjdk.utils.ValidationUtils; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -public final class RANS { - - public enum ORDER { - ZERO, ONE; - - public static ORDER fromInt(final int orderValue) { - try { - return ORDER.values()[orderValue]; - } catch (final ArrayIndexOutOfBoundsException e) { - throw new IllegalArgumentException("Unknown rANS order: " + orderValue); - } - } - } - - // A compressed rANS stream consists of a prefix containing 3 values, followed by the compressed data block: - // byte - order of the codec (0 or 1) - // int - total compressed size of the frequency table and compressed content - // int - total size of the raw/uncompressed content - // byte[] - frequency table (RLE) - // byte[] - compressed data - - private static final int ORDER_BYTE_LENGTH = 1; - private static final int COMPRESSED_BYTE_LENGTH = 4; - private static final int 
RAW_BYTE_LENGTH = 4; - private static final int PREFIX_BYTE_LENGTH = ORDER_BYTE_LENGTH + COMPRESSED_BYTE_LENGTH + RAW_BYTE_LENGTH; - - // streams smaller than this value don't have sufficient symbol context for ORDER-1 encoding, - // so always use ORDER-0 - private static final int MINIMUM__ORDER_1_SIZE = 4; - private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); - - public static final int NUMBER_OF_SYMBOLS = 256; - - // working variables used by the encoder and decoder; initialize them lazily since - // they consist of lots of small objects, and we don't want to instantiate them - // until we actually use them - private ArithmeticDecoder[] D; - private RANSDecodingSymbol[][] decodingSymbols; - private RANSEncodingSymbol[][] encodingSymbols; - - // Lazy initialization of working memory for the encoder/decoder - private void initializeRANSCoder() { - if (D == null) { - D = new ArithmeticDecoder[NUMBER_OF_SYMBOLS]; - for (int i = 0; i < NUMBER_OF_SYMBOLS; i++) { - D[i] = new ArithmeticDecoder(); - } - } else { - for (int i = 0; i < NUMBER_OF_SYMBOLS; i++) { - D[i].reset(); - } - } - if (decodingSymbols == null) { - decodingSymbols = new RANSDecodingSymbol[NUMBER_OF_SYMBOLS][NUMBER_OF_SYMBOLS]; - for (int i = 0; i < decodingSymbols.length; i++) { - for (int j = 0; j < decodingSymbols[i].length; j++) { - decodingSymbols[i][j] = new RANSDecodingSymbol(); - } - } - } else { - for (int i = 0; i < decodingSymbols.length; i++) { - for (int j = 0; j < decodingSymbols[i].length; j++) { - decodingSymbols[i][j].set(0, 0); - } - } - } - if (encodingSymbols == null) { - encodingSymbols = new RANSEncodingSymbol[NUMBER_OF_SYMBOLS][NUMBER_OF_SYMBOLS]; - for (int i = 0; i < encodingSymbols.length; i++) { - for (int j = 0; j < encodingSymbols[i].length; j++) { - encodingSymbols[i][j] = new RANSEncodingSymbol(); - } - } - } else { - for (int i = 0; i < encodingSymbols.length; i++) { - for (int j = 0; j < encodingSymbols[i].length; j++) { - 
encodingSymbols[i][j].reset(); - } - } - } - } - - public ByteBuffer uncompress(final ByteBuffer inBuffer) { - if (inBuffer.remaining() == 0) { - return EMPTY_BUFFER; - } - - initializeRANSCoder(); - - final ORDER order = ORDER.fromInt(inBuffer.get()); - - inBuffer.order(ByteOrder.LITTLE_ENDIAN); - final int inSize = inBuffer.getInt(); - if (inSize != inBuffer.remaining() - RAW_BYTE_LENGTH) { - throw new RuntimeException("Incorrect input length."); - } - final int outSize = inBuffer.getInt(); - final ByteBuffer outBuffer = ByteBuffer.allocate(outSize); - - switch (order) { - case ZERO: - return uncompressOrder0Way4(inBuffer, outBuffer); - - case ONE: - return uncompressOrder1Way4(inBuffer, outBuffer); - - default: - throw new RuntimeException("Unknown rANS order: " + order); - } - } - - public ByteBuffer compress(final ByteBuffer inBuffer, final ORDER order) { - if (inBuffer.remaining() == 0) { - return EMPTY_BUFFER; - } - - initializeRANSCoder(); - - if (inBuffer.remaining() < MINIMUM__ORDER_1_SIZE) { - // ORDER-1 encoding of less than 4 bytes is not permitted, so just use ORDER-0 - return compressOrder0Way4(inBuffer); - } - - switch (order) { - case ZERO: - return compressOrder0Way4(inBuffer); - - case ONE: - return compressOrder1Way4(inBuffer); - - default: - throw new RuntimeException("Unknown rANS order: " + order); - } - } - - private ByteBuffer compressOrder0Way4(final ByteBuffer inBuffer) { - final int inSize = inBuffer.remaining(); - final ByteBuffer outBuffer = allocateOutputBuffer(inSize); - - // move the output buffer ahead to the start of the frequency table (we'll come back and - // write the output stream prefix at the end of this method) - outBuffer.position(PREFIX_BYTE_LENGTH); // start of frequency table - - final int[] F = Frequencies.calcFrequenciesOrder0(inBuffer); - Frequencies.buildSymsOrder0(F, encodingSymbols[0]); - - final ByteBuffer cp = outBuffer.slice(); - final int frequencyTableSize = Frequencies.writeFrequenciesOrder0(cp, F); - - 
inBuffer.rewind(); - final int compressedBlobSize = E04.compress(inBuffer, encodingSymbols[0], cp); - - // rewind and write the prefix - writeCompressionPrefix(ORDER.ZERO, outBuffer, inSize, frequencyTableSize, compressedBlobSize); - return outBuffer; - } - - private ByteBuffer compressOrder1Way4(final ByteBuffer inBuffer) { - final int inSize = inBuffer.remaining(); - final ByteBuffer outBuffer = allocateOutputBuffer(inSize); - - // move to start of frequency - outBuffer.position(PREFIX_BYTE_LENGTH); - - final int[][] F = Frequencies.calcFrequenciesOrder1(inBuffer); - Frequencies.buildSymsOrder1(F, encodingSymbols); - - final ByteBuffer cp = outBuffer.slice(); - final int frequencyTableSize = Frequencies.writeFrequenciesOrder1(cp, F); - - inBuffer.rewind(); - final int compressedBlobSize = E14.compress(inBuffer, encodingSymbols, cp); - - // rewind and write the prefix - writeCompressionPrefix(ORDER.ONE, outBuffer, inSize, frequencyTableSize, compressedBlobSize); - return outBuffer; - } - - private ByteBuffer uncompressOrder0Way4(final ByteBuffer inBuffer, final ByteBuffer outBuffer) { - Frequencies.readStatsOrder0(inBuffer, D[0], decodingSymbols[0]); - D04.uncompress(inBuffer, D[0], decodingSymbols[0], outBuffer); - - return outBuffer; - } - - private ByteBuffer uncompressOrder1Way4(final ByteBuffer in, final ByteBuffer outBuffer) { - Frequencies.readStatsOrder1(in, D, decodingSymbols); - D14.uncompress(in, outBuffer, D, decodingSymbols); - return outBuffer; - } - - private static ByteBuffer allocateOutputBuffer(final int inSize) { - // This calculation is identical to the one in samtools rANS_static.c - // Presumably the frequency table (always big enough for order 1) = 257*257, then * 3 for each entry - // (byte->symbol, 2 bytes -> scaled frequency), + 9 for the header (order byte, and 2 int lengths - // for compressed/uncompressed lengths) ? Plus additional 5% for..., for what ??? 
- final int compressedSize = (int) (1.05 * inSize + 257 * 257 * 3 + 9); - final ByteBuffer outputBuffer = ByteBuffer.allocate(compressedSize); - if (outputBuffer.remaining() < compressedSize) { - throw new RuntimeException("Failed to allocate sufficient buffer size for RANS coder."); - } - outputBuffer.order(ByteOrder.LITTLE_ENDIAN); - return outputBuffer; - } - - private static void writeCompressionPrefix( - final ORDER order, - final ByteBuffer outBuffer, - final int inSize, - final int frequencyTableSize, - final int compressedBlobSize) { - ValidationUtils.validateArg(order == ORDER.ONE || order == ORDER.ZERO,"unrecognized RANS order"); - outBuffer.limit(PREFIX_BYTE_LENGTH + frequencyTableSize + compressedBlobSize); - - // go back to the beginning of the stream and write the prefix values - // write the (ORDER as a single byte at offset 0) - outBuffer.put(0, (byte) (order == ORDER.ZERO ? 0 : 1)); - outBuffer.order(ByteOrder.LITTLE_ENDIAN); - // move past the ORDER and write the compressed size - outBuffer.putInt(ORDER_BYTE_LENGTH, frequencyTableSize + compressedBlobSize); - // move past the compressed size and write the uncompressed size - outBuffer.putInt(ORDER_BYTE_LENGTH + COMPRESSED_BYTE_LENGTH, inSize); - outBuffer.rewind(); - } - -} diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/RANSDecode.java b/src/main/java/htsjdk/samtools/cram/compression/rans/RANSDecode.java new file mode 100644 index 0000000000..154cfa9614 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/RANSDecode.java @@ -0,0 +1,51 @@ +package htsjdk.samtools.cram.compression.rans; + +import java.nio.ByteBuffer; + +public abstract class RANSDecode { + private ArithmeticDecoder[] D; + private RANSDecodingSymbol[][] decodingSymbols; + + // GETTERS + protected ArithmeticDecoder[] getD() { + return D; + } + + protected RANSDecodingSymbol[][] getDecodingSymbols() { + return decodingSymbols; + } + + // This method assumes that inBuffer is already rewound. 
+ // It uncompresses the data in the inBuffer, leaving it consumed. + // Returns a rewound ByteBuffer containing the uncompressed data. + public abstract ByteBuffer uncompress(final ByteBuffer inBuffer); + + // Lazy initialization of working memory for the decoder + protected void initializeRANSDecoder() { + if (D == null) { + D = new ArithmeticDecoder[Constants.NUMBER_OF_SYMBOLS]; + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + D[i] = new ArithmeticDecoder(); + } + } else { + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + D[i].reset(); + } + } + if (decodingSymbols == null) { + decodingSymbols = new RANSDecodingSymbol[Constants.NUMBER_OF_SYMBOLS][Constants.NUMBER_OF_SYMBOLS]; + for (int i = 0; i < decodingSymbols.length; i++) { + for (int j = 0; j < decodingSymbols[i].length; j++) { + decodingSymbols[i][j] = new RANSDecodingSymbol(); + } + } + } else { + for (int i = 0; i < decodingSymbols.length; i++) { + for (int j = 0; j < decodingSymbols[i].length; j++) { + decodingSymbols[i][j].set(0, 0); + } + } + } + } + +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/RANSDecodingSymbol.java b/src/main/java/htsjdk/samtools/cram/compression/rans/RANSDecodingSymbol.java index 3dde8f8c02..34d0bc7dda 100644 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/RANSDecodingSymbol.java +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/RANSDecodingSymbol.java @@ -26,7 +26,7 @@ import java.nio.ByteBuffer; -final class RANSDecodingSymbol { +final public class RANSDecodingSymbol { int start; // Start of range. int freq; // Symbol frequency. @@ -42,7 +42,7 @@ public void set(final int start, final int freq) { // "start" and frequency "freq". All frequencies are assumed to sum to // "1 << scale_bits". // No renormalization or output happens. 
- public int advanceSymbolStep(final int r, final int scaleBits) { + public long advanceSymbolStep(final long r, final int scaleBits) { final int mask = ((1 << scaleBits) - 1); // s, x = D(x) @@ -52,22 +52,34 @@ public int advanceSymbolStep(final int r, final int scaleBits) { // Advances in the bit stream by "popping" a single symbol with range start // "start" and frequency "freq". All frequencies are assumed to sum to // "1 << scale_bits". - public int advanceSymbol(final int rIn, final ByteBuffer byteBuffer, final int scaleBits) { + public long advanceSymbol4x8(final long rIn, final ByteBuffer byteBuffer, final int scaleBits) { final int mask = (1 << scaleBits) - 1; // s, x = D(x) - int r = rIn; - r = freq * (r >> scaleBits) + (r & mask) - start; + long ret = freq * (rIn >> scaleBits) + (rIn & mask) - start; // re-normalize - if (r < Constants.RANS_BYTE_L) { + if (ret < Constants.RANS_4x8_LOWER_BOUND) { do { final int b = 0xFF & byteBuffer.get(); - r = (r << 8) | b; - } while (r < Constants.RANS_BYTE_L); + ret = (ret << 8) | b; + } while (ret < Constants.RANS_4x8_LOWER_BOUND); } + return ret; + } + + public long advanceSymbolNx16(final long rIn, final ByteBuffer byteBuffer, final int scaleBits) { + final int mask = (1 << scaleBits) - 1; - return r; + // s, x = D(x) + long ret = freq * (rIn >> scaleBits) + (rIn & mask) - start; + + // re-normalize + if (ret < (Constants.RANS_Nx16_LOWER_BOUND)){ + final int i = (0xFF & byteBuffer.get()) | ((0xFF & byteBuffer.get()) << 8); + ret = (ret << 16) + i; + } + return ret; } -} +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/RANSEncode.java b/src/main/java/htsjdk/samtools/cram/compression/rans/RANSEncode.java new file mode 100644 index 0000000000..49b12dd275 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/RANSEncode.java @@ -0,0 +1,58 @@ +package htsjdk.samtools.cram.compression.rans; + +import java.nio.ByteBuffer; + +public abstract class RANSEncode<T> {
+ private RANSEncodingSymbol[][] encodingSymbols; + + // Getter + protected RANSEncodingSymbol[][] getEncodingSymbols() { + return encodingSymbols; + } + + // This method assumes that inBuffer is already rewound. + // It compresses the data in the inBuffer, leaving it consumed. + // Returns a rewound ByteBuffer containing the compressed data. + public abstract ByteBuffer compress(final ByteBuffer inBuffer, final T params); + + // Lazy initialization of working memory for the encoder + protected void initializeRANSEncoder() { + if (encodingSymbols == null) { + encodingSymbols = new RANSEncodingSymbol[Constants.NUMBER_OF_SYMBOLS][Constants.NUMBER_OF_SYMBOLS]; + for (int i = 0; i < encodingSymbols.length; i++) { + for (int j = 0; j < encodingSymbols[i].length; j++) { + encodingSymbols[i][j] = new RANSEncodingSymbol(); + } + } + } else { + for (int i = 0; i < encodingSymbols.length; i++) { + for (int j = 0; j < encodingSymbols[i].length; j++) { + encodingSymbols[i][j].reset(); + } + } + } + } + + protected void buildSymsOrder0(final int[] frequencies) { + updateEncodingSymbols(frequencies, getEncodingSymbols()[0]); + } + + protected void buildSymsOrder1(final int[][] frequencies) { + final RANSEncodingSymbol[][] encodingSymbols = getEncodingSymbols(); + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + updateEncodingSymbols(frequencies[i], encodingSymbols[i]); + } + } + + private void updateEncodingSymbols(int[] frequencies, RANSEncodingSymbol[] encodingSymbols) { + int cumulativeFreq = 0; + for (int symbol = 0; symbol < Constants.NUMBER_OF_SYMBOLS; symbol++) { + if (frequencies[symbol] != 0) { + //For each symbol, set start = cumulative frequency and freq = frequencies[symbol] + encodingSymbols[symbol].set(cumulativeFreq, frequencies[symbol], Constants.TOTAL_FREQ_SHIFT); + cumulativeFreq += frequencies[symbol]; + } + } + } + +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/RANSEncodingSymbol.java 
b/src/main/java/htsjdk/samtools/cram/compression/rans/RANSEncodingSymbol.java index 2d70255416..8188d1a825 100644 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/RANSEncodingSymbol.java +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/RANSEncodingSymbol.java @@ -28,22 +28,23 @@ import java.nio.ByteBuffer; -final class RANSEncodingSymbol { - private int xMax; // (Exclusive) upper bound of pre-normalization interval +public final class RANSEncodingSymbol { + private long xMax; // (Exclusive) upper bound of pre-normalization interval private int rcpFreq; // Fixed-point reciprocal frequency private int bias; // Bias private int cmplFreq; // Complement of frequency: (1 << scaleBits) - freq private int rcpShift; // Reciprocal shift public void reset() { - xMax = rcpFreq = bias = cmplFreq = rcpFreq = 0; + xMax = rcpFreq = bias = cmplFreq = rcpShift = 0; } public void set(final int start, final int freq, final int scaleBits) { - // RansAssert(scale_bits <= 16); RansAssert(start <= (1u << - // scale_bits)); RansAssert(freq <= (1u << scale_bits) - start); - xMax = ((Constants.RANS_BYTE_L >> scaleBits) << 8) * freq; + // Rans4x8: xMax = ((Constants.RANS_BYTE_L_4x8 >> scaleBits) << 8) * freq = (1<< 31-scaleBits) * freq + // RansNx16: xMax = ((Constants.RANS_BYTE_L_Nx16 >> scaleBits) << 16) * freq = (1<< 31-scaleBits) * freq + // why freq > 4095 in Nx16? + xMax = (1L<< (31-scaleBits)) * freq; cmplFreq = (1 << scaleBits) - freq; if (freq < 2) { rcpFreq = (int) ~0L; @@ -56,7 +57,6 @@ public void set(final int start, final int freq, final int scaleBits) { while (freq > (1L << shift)) { shift++; } - rcpFreq = (int) (((1L << (shift + 31)) + freq - 1) / freq); rcpShift = shift - 1; @@ -64,21 +64,47 @@ public void set(final int start, final int freq, final int scaleBits) { // have bias=start. 
bias = start; } - rcpShift += 32; // Avoid the extra >>32 in RansEncPutSymbol } - public int putSymbol(int r, final ByteBuffer byteBuffer) { + public long putSymbol4x8(final long r, final ByteBuffer byteBuffer) { + ValidationUtils.validateArg(xMax != 0, "can't encode symbol with freq=0"); + + // re-normalize + long retSymbol = r; + if (retSymbol >= xMax) { + byteBuffer.put((byte) (retSymbol & 0xFF)); + retSymbol >>= 8; + if (retSymbol >= xMax) { + byteBuffer.put((byte) (retSymbol & 0xFF)); + retSymbol >>= 8; + } + } + + // x = C(s,x) + // NOTE: written this way so we get a 32-bit "multiply high" when + // available. If you're on a 64-bit platform with cheap multiplies + // (e.g. x64), just bake the +32 into rcp_shift. + // int q = (int) (((uint64_t)x * sym.rcp_freq) >> 32) >> sym.rcp_shift; + + // The extra >>32 has already been added to RansEncSymbolInit + final long q = ((retSymbol * (0xFFFFFFFFL & rcpFreq)) >> rcpShift); + return retSymbol + bias + q * cmplFreq; + } + + public long putSymbolNx16(final long r, final ByteBuffer byteBuffer) { ValidationUtils.validateArg(xMax != 0, "can't encode symbol with freq=0"); // re-normalize - int x = r; - if (x >= xMax) { - byteBuffer.put((byte) (x & 0xFF)); - x >>= 8; - if (x >= xMax) { - byteBuffer.put((byte) (x & 0xFF)); - x >>= 8; + long retSymbol = r; + if (retSymbol >= xMax) { + byteBuffer.put((byte) ((retSymbol>>8) & 0xFF)); // extra line - 1 more byte + byteBuffer.put((byte) (retSymbol & 0xFF)); + retSymbol >>=16; + if (retSymbol >= xMax) { + byteBuffer.put((byte) ((retSymbol>>8) & 0xFF)); // extra line - 1 more byte + byteBuffer.put((byte) (retSymbol & 0xFF)); + retSymbol >>=16; } } @@ -89,8 +115,7 @@ public int putSymbol(int r, final ByteBuffer byteBuffer) { // int q = (int) (((uint64_t)x * sym.rcp_freq) >> 32) >> sym.rcp_shift; // The extra >>32 has already been added to RansEncSymbolInit - final long q = ((x * (0xFFFFFFFFL & rcpFreq)) >> rcpShift); - r = (int) (x + bias + q * cmplFreq); - return r; + final long 
q = ((retSymbol * (0xFFFFFFFFL & rcpFreq)) >> rcpShift); + return retSymbol + bias + q * cmplFreq; } -} +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/RANSParams.java b/src/main/java/htsjdk/samtools/cram/compression/rans/RANSParams.java new file mode 100644 index 0000000000..7d617d5249 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/RANSParams.java @@ -0,0 +1,21 @@ +package htsjdk.samtools.cram.compression.rans; + +public interface RANSParams { + + enum ORDER { + ZERO, ONE; + + public static ORDER fromInt(final int orderValue) { + try { + return ORDER.values()[orderValue]; + } catch (final ArrayIndexOutOfBoundsException e) { + throw new IllegalArgumentException("Unknown rANS order: " + orderValue, e); + } + } + } + + int getFormatFlags(); + + ORDER getOrder(); + +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/Utils.java b/src/main/java/htsjdk/samtools/cram/compression/rans/Utils.java index d2da830eb5..06abbca89d 100644 --- a/src/main/java/htsjdk/samtools/cram/compression/rans/Utils.java +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/Utils.java @@ -2,7 +2,7 @@ import java.nio.ByteBuffer; -final class Utils { +final public class Utils { private static void reverse(final byte[] array, final int offset, final int size) { if (array == null) { @@ -10,9 +10,8 @@ private static void reverse(final byte[] array, final int offset, final int size } int i = offset; int j = offset + size - 1; - byte tmp; while (j > i) { - tmp = array[j]; + byte tmp = array[j]; array[j] = array[i]; array[i] = tmp; j--; @@ -20,33 +19,151 @@ private static void reverse(final byte[] array, final int offset, final int size } } - static void reverse(final ByteBuffer byteBuffer) { - byte tmp; + public static void reverse(final ByteBuffer byteBuffer) { if (byteBuffer.hasArray()) { reverse(byteBuffer.array(), byteBuffer.arrayOffset(), byteBuffer.limit()); } else { for (int i 
= 0; i < byteBuffer.limit(); i++) { - tmp = byteBuffer.get(i); byteBuffer.put(i, byteBuffer.get(byteBuffer.limit() - i - 1)); - byteBuffer.put(byteBuffer.limit() - i - 1, tmp); + byteBuffer.put(byteBuffer.limit() - i - 1, byteBuffer.get(i)); } } } // Returns the current cumulative frequency (map it to a symbol yourself!) - static int RANSDecodeGet(final int r, final int scaleBits) { - return r & ((1 << scaleBits) - 1); + public static int RANSGetCumulativeFrequency(final long r, final int scaleBits) { + return (int) (r & ((1 << scaleBits) - 1)); // since cumulative frequency will be a maximum of 4096 } - // Re-normalize. - static int RANSDecodeRenormalize(int r, final ByteBuffer byteBuffer) { - // re-normalize - if (r < Constants.RANS_BYTE_L) { - do { - r = (r << 8) | (0xFF & byteBuffer.get()); - } while (r < Constants.RANS_BYTE_L); + public static long RANSDecodeRenormalize4x8(final long r, final ByteBuffer byteBuffer) { + long ret = r; + while (ret < Constants.RANS_4x8_LOWER_BOUND) { + ret = (ret << 8) | (0xFF & byteBuffer.get()); } + return ret; + } - return r; + public static long RANSDecodeRenormalizeNx16(final long r, final ByteBuffer byteBuffer) { + long ret = r; + if (ret < (Constants.RANS_Nx16_LOWER_BOUND)) { + final int i = (0xFF & byteBuffer.get()) | ((0xFF & byteBuffer.get()) << 8); + ret = (ret << 16) | i; + } + return ret; } -} + + public static void normaliseFrequenciesOrder0(final int[] F, final int bits) { + // Returns an array of normalised Frequencies, + // such that the frequencies add up to 1<0)?(((long) (renormFreq) << 31) / T + (1 << 30) / T):0; + int fsum = 0; + for (int symbol = 0; symbol < Constants.NUMBER_OF_SYMBOLS; symbol++) { + if (F[symbol] == 0) { + continue; + } + + // As per spec, total frequencies after normalization should be 4096 (4095 could be considered legacy value) + // using tr to normalize symbol frequencies such that their total = renormFreq + if ((F[symbol] = (int) ((F[symbol] * tr) >> 31)) == 0) { + + // A non-zero 
symbol frequency should not be incorrectly set to 0. + // If the calculated value is 0, change it to 1 + F[symbol] = 1; + } + fsum += F[symbol]; + } + + // adjust the frequency of the symbol "M" such that + // the sum of frequencies of all the symbols = renormFreq + if (fsum < renormFreq) { + F[M] += renormFreq - fsum; + } else if (fsum > renormFreq) { + F[M] -= fsum - renormFreq; + } + } + + public static void normaliseFrequenciesOrder1(final int[][] F, final int shift) { + // calculate the minimum bit size required for representing the frequency array for each symbol + // and normalise the frequency array using the calculated bit size + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + if (F[Constants.NUMBER_OF_SYMBOLS][j]==0){ + continue; + } + + // log2 N = Math.log(N)/Math.log(2) + int bitSize = (int) Math.ceil(Math.log(F[Constants.NUMBER_OF_SYMBOLS][j]) / Math.log(2)); + if (bitSize > shift) + bitSize = shift; + + // TODO: check if handling bitSize = 0 is required + if (bitSize == 0) + bitSize = 1; // bitSize cannot be zero + + // special case -> if a symbol occurs only once and at the end of the input, + // then the order 0 freq table associated with it should have a frequency of 1 for symbol 0 + // i.e, F[sym][0] = 1 + normaliseFrequenciesOrder0(F[j], bitSize); + } + } + + public static void normaliseFrequenciesOrder0Shift(final int[] frequencies, final int bits){ + + // compute total frequency + int totalFrequency = 0; + for (int freq : frequencies) { + totalFrequency += freq; + } + if (totalFrequency == 0 || totalFrequency == (1<> 2; + int i0 = 0; + int i1 = isz4; + int i2 = 2 * isz4; + int i7 = 3 * isz4; + byte l0 = 0; + byte l1 = 0; + byte l2 = 0; + byte l7 = 0; + final ArithmeticDecoder[] D = getD(); + final RANSDecodingSymbol[][] syms = getDecodingSymbols(); + for (; i0 < isz4; i0++, i1++, i2++, i7++) { + final byte c0 = D[0xFF & l0].reverseLookup[Utils.RANSGetCumulativeFrequency(rans0, Constants.TOTAL_FREQ_SHIFT)]; + final byte c1 = D[0xFF & 
l1].reverseLookup[Utils.RANSGetCumulativeFrequency(rans1, Constants.TOTAL_FREQ_SHIFT)]; + final byte c2 = D[0xFF & l2].reverseLookup[Utils.RANSGetCumulativeFrequency(rans2, Constants.TOTAL_FREQ_SHIFT)]; + final byte c7 = D[0xFF & l7].reverseLookup[Utils.RANSGetCumulativeFrequency(rans7, Constants.TOTAL_FREQ_SHIFT)]; + + outBuffer.put(i0, c0); + outBuffer.put(i1, c1); + outBuffer.put(i2, c2); + outBuffer.put(i7, c7); + + rans0 = syms[0xFF & l0][0xFF & c0].advanceSymbolStep(rans0, Constants.TOTAL_FREQ_SHIFT); + rans1 = syms[0xFF & l1][0xFF & c1].advanceSymbolStep(rans1, Constants.TOTAL_FREQ_SHIFT); + rans2 = syms[0xFF & l2][0xFF & c2].advanceSymbolStep(rans2, Constants.TOTAL_FREQ_SHIFT); + rans7 = syms[0xFF & l7][0xFF & c7].advanceSymbolStep(rans7, Constants.TOTAL_FREQ_SHIFT); + + rans0 = Utils.RANSDecodeRenormalize4x8(rans0, inBuffer); + rans1 = Utils.RANSDecodeRenormalize4x8(rans1, inBuffer); + rans2 = Utils.RANSDecodeRenormalize4x8(rans2, inBuffer); + rans7 = Utils.RANSDecodeRenormalize4x8(rans7, inBuffer); + + l0 = c0; + l1 = c1; + l2 = c2; + l7 = c7; + } + + // Remainder + for (; i7 < out_sz; i7++) { + final byte c7 = D[0xFF & l7].reverseLookup[Utils.RANSGetCumulativeFrequency(rans7, Constants.TOTAL_FREQ_SHIFT)]; + outBuffer.put(i7, c7); + rans7 = syms[0xFF & l7][0xFF & c7].advanceSymbol4x8(rans7, inBuffer, Constants.TOTAL_FREQ_SHIFT); + // TODO: the spec specifies renormalize here + // rans7 = Utils.RANSDecodeRenormalize4x8(rans7, inBuffer); + l7 = c7; + } + } + + private void readStatsOrder0(final ByteBuffer cp) { + // Pre-compute reverse lookup of frequency. 
+ final ArithmeticDecoder decoder = getD()[0]; + final RANSDecodingSymbol[] decodingSymbols = getDecodingSymbols()[0]; + int rle = 0; + int cumulativeFrequency = 0; + int symbol = cp.get() & 0xFF; + do { + if ((decoder.frequencies[symbol] = (cp.get() & 0xFF)) >= 0x80) { + + // read a variable sized unsigned integer with ITF8 encoding + decoder.frequencies[symbol] &= ~0x80; + decoder.frequencies[symbol] = ((decoder.frequencies[symbol] & 0x7F) << 8) | (cp.get() & 0xFF); + } + + decodingSymbols[symbol].set(cumulativeFrequency, decoder.frequencies[symbol]); + + /* Build reverse lookup table */ + Arrays.fill(decoder.reverseLookup, cumulativeFrequency, cumulativeFrequency + decoder.frequencies[symbol], (byte) symbol); + + cumulativeFrequency += decoder.frequencies[symbol]; + + if (rle == 0 && symbol + 1 == (0xFF & cp.get(cp.position()))) { + symbol = cp.get() & 0xFF; + rle = cp.get() & 0xFF; + } else if (rle != 0) { + rle--; + symbol++; + } else { + symbol = cp.get() & 0xFF; + } + } while (symbol != 0); + + assert (cumulativeFrequency <= Constants.TOTAL_FREQ); + } + + private void readStatsOrder1(final ByteBuffer cp) { + final ArithmeticDecoder[] D = getD(); + final RANSDecodingSymbol[][] decodingSymbols = getDecodingSymbols(); + int rle_i = 0; + int i = 0xFF & cp.get(); + do { + int rle_j = 0; + int cumulativeFrequency = 0; + int j = 0xFF & cp.get(); + do { + if ((D[i].frequencies[j] = (0xFF & cp.get())) >= 0x80) { + + // read a variable sized unsigned integer with ITF8 encoding + D[i].frequencies[j] &= ~0x80; + D[i].frequencies[j] = ((D[i].frequencies[j] & 0x7F) << 8) | (0xFF & cp.get()); + } + + if (D[i].frequencies[j] == 0) { + D[i].frequencies[j] = Constants.TOTAL_FREQ; + } + + decodingSymbols[i][j].set( + cumulativeFrequency, + D[i].frequencies[j] + ); + + /* Build reverse lookup table */ + Arrays.fill(D[i].reverseLookup, cumulativeFrequency, cumulativeFrequency + D[i].frequencies[j], (byte) j); + + cumulativeFrequency += D[i].frequencies[j]; + assert 
(cumulativeFrequency <= Constants.TOTAL_FREQ); + + if (rle_j == 0 && j + 1 == (0xFF & cp.get(cp.position()))) { + j = (0xFF & cp.get()); + rle_j = (0xFF & cp.get()); + } else if (rle_j != 0) { + rle_j--; + j++; + } else { + j = (0xFF & cp.get()); + } + } while (j != 0); + + if (rle_i == 0 && i + 1 == (0xFF & cp.get(cp.position()))) { + i = (0xFF & cp.get()); + rle_i = (0xFF & cp.get()); + } else if (rle_i != 0) { + rle_i--; + i++; + } else { + i = (0xFF & cp.get()); + } + } while (i != 0); + } + +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/rans4x8/RANS4x8Encode.java b/src/main/java/htsjdk/samtools/cram/compression/rans/rans4x8/RANS4x8Encode.java new file mode 100644 index 0000000000..638882fb67 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/rans4x8/RANS4x8Encode.java @@ -0,0 +1,442 @@ +package htsjdk.samtools.cram.compression.rans.rans4x8; + +import htsjdk.samtools.cram.CRAMException; +import htsjdk.samtools.cram.compression.CompressionUtils; +import htsjdk.samtools.cram.compression.rans.Constants; +import htsjdk.samtools.cram.compression.rans.RANSEncode; +import htsjdk.samtools.cram.compression.rans.RANSEncodingSymbol; +import htsjdk.samtools.cram.compression.rans.RANSParams; +import htsjdk.samtools.cram.compression.rans.Utils; +import htsjdk.utils.ValidationUtils; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +public class RANS4x8Encode extends RANSEncode { + + // streams smaller than this value don't have sufficient symbol context for ORDER-1 encoding, + // so always use ORDER-0 + private static final int MINIMUM_ORDER_1_SIZE = 4; + private static final ByteBuffer EMPTY_BUFFER = CompressionUtils.allocateByteBuffer(0); + + // This method assumes that inBuffer is already rewound. + // It compresses the data in the inBuffer, leaving it consumed. + // Returns a rewound ByteBuffer containing the compressed data. 
+ public ByteBuffer compress(final ByteBuffer inBuffer, final RANS4x8Params params) { + if (inBuffer.remaining() == 0) { + return EMPTY_BUFFER; + } + initializeRANSEncoder(); + if (inBuffer.remaining() < MINIMUM_ORDER_1_SIZE) { + // ORDER-1 encoding of less than 4 bytes is not permitted, so just use ORDER-0 + return compressOrder0Way4(inBuffer); + } + final RANSParams.ORDER order= params.getOrder(); + switch (order) { + case ZERO: + return compressOrder0Way4(inBuffer); + + case ONE: + return compressOrder1Way4(inBuffer); + + default: + throw new CRAMException("Unknown rANS order: " + params.getOrder()); + } + } + + private ByteBuffer compressOrder0Way4(final ByteBuffer inBuffer) { + final int inputSize = inBuffer.remaining(); + final ByteBuffer outBuffer = CompressionUtils.allocateOutputBuffer(inputSize); + + // move the output buffer ahead to the start of the frequency table (we'll come back and + // write the output stream prefix at the end of this method) + outBuffer.position(Constants.RANS_4x8_PREFIX_BYTE_LENGTH); // start of frequency table + + // get the normalised frequencies of the alphabets + final int[] normalizedFreq = calcFrequenciesOrder0(inBuffer); + + // using the normalised frequencies, set the RANSEncodingSymbols + buildSymsOrder0(normalizedFreq); + final ByteBuffer cp = CompressionUtils.slice(outBuffer); + + // write Frequency table + final int frequencyTableSize = writeFrequenciesOrder0(cp, normalizedFreq); + + inBuffer.rewind(); + + final RANSEncodingSymbol[] syms = getEncodingSymbols()[0]; + final int in_size = inBuffer.remaining(); + long rans0, rans1, rans2, rans3; + final ByteBuffer ptr = CompressionUtils.slice(cp); + rans0 = Constants.RANS_4x8_LOWER_BOUND; + rans1 = Constants.RANS_4x8_LOWER_BOUND; + rans2 = Constants.RANS_4x8_LOWER_BOUND; + rans3 = Constants.RANS_4x8_LOWER_BOUND; + + int i; + switch (i = (in_size & 3)) { + case 3: + rans2 = syms[0xFF & inBuffer.get(in_size - (i - 2))].putSymbol4x8(rans2, ptr); + case 2: + rans1 = syms[0xFF 
& inBuffer.get(in_size - (i - 1))].putSymbol4x8(rans1, ptr); + case 1: + rans0 = syms[0xFF & inBuffer.get(in_size - (i))].putSymbol4x8(rans0, ptr); + case 0: + break; + } + for (i = (in_size & ~3); i > 0; i -= 4) { + final byte c3 = inBuffer.get(i - 1); + final byte c2 = inBuffer.get(i - 2); + final byte c1 = inBuffer.get(i - 3); + final byte c0 = inBuffer.get(i - 4); + + rans3 = syms[0xFF & c3].putSymbol4x8(rans3, ptr); + rans2 = syms[0xFF & c2].putSymbol4x8(rans2, ptr); + rans1 = syms[0xFF & c1].putSymbol4x8(rans1, ptr); + rans0 = syms[0xFF & c0].putSymbol4x8(rans0, ptr); + } + + ptr.order(ByteOrder.BIG_ENDIAN); + ptr.putInt((int) rans3); + ptr.putInt((int) rans2); + ptr.putInt((int) rans1); + ptr.putInt((int) rans0); + ptr.flip(); + final int cdata_size = ptr.limit(); + // reverse the compressed bytes, so that they become in REVERSE order: + Utils.reverse(ptr); + inBuffer.position(inBuffer.limit()); + + // write the prefix at the beginning of the output buffer + writeCompressionPrefix(RANSParams.ORDER.ZERO, outBuffer, inputSize, frequencyTableSize, cdata_size); + return outBuffer; + } + + private ByteBuffer compressOrder1Way4(final ByteBuffer inBuffer) { + final int inSize = inBuffer.remaining(); + final ByteBuffer outBuffer = CompressionUtils.allocateOutputBuffer(inSize); + + // move to start of frequency + outBuffer.position(Constants.RANS_4x8_PREFIX_BYTE_LENGTH); + + // get normalized frequencies + final int[][] normalizedFreq = calcFrequenciesOrder1(inBuffer); + + // using the normalised frequencies, set the RANSEncodingSymbols + buildSymsOrder1(normalizedFreq); + + final ByteBuffer cp = CompressionUtils.slice(outBuffer); + final int frequencyTableSize = writeFrequenciesOrder1(cp, normalizedFreq); + inBuffer.rewind(); + final int in_size = inBuffer.remaining(); + long rans0, rans1, rans2, rans3; + rans0 = Constants.RANS_4x8_LOWER_BOUND; + rans1 = Constants.RANS_4x8_LOWER_BOUND; + rans2 = Constants.RANS_4x8_LOWER_BOUND; + rans3 = 
Constants.RANS_4x8_LOWER_BOUND; + + final int isz4 = in_size >> 2; + int i0 = isz4 - 2; + int i1 = 2 * isz4 - 2; + int i2 = 3 * isz4 - 2; + int i3 = 4 * isz4 - 2; + + byte l0 = 0; + if (i0 + 1 >= 0) { + l0 = inBuffer.get(i0 + 1); + } + byte l1 = 0; + if (i1 + 1 >= 0) { + l1 = inBuffer.get(i1 + 1); + } + byte l2 = 0; + if (i2 + 1 >= 0) { + l2 = inBuffer.get(i2 + 1); + } + + // Deal with the remainder + byte l3 = inBuffer.get(in_size - 1); + + // Slicing is needed for buffer reversing later + final ByteBuffer ptr = CompressionUtils.slice(cp); + final RANSEncodingSymbol[][] syms = getEncodingSymbols(); + for (i3 = in_size - 2; i3 > 4 * isz4 - 2 && i3 >= 0; i3--) { + final byte c3 = inBuffer.get(i3); + rans3 = syms[0xFF & c3][0xFF & l3].putSymbol4x8(rans3, ptr); + l3 = c3; + } + + for (; i0 >= 0; i0--, i1--, i2--, i3--) { + final byte c0 = inBuffer.get(i0); + final byte c1 = inBuffer.get(i1); + final byte c2 = inBuffer.get(i2); + final byte c3 = inBuffer.get(i3); + + rans3 = syms[0xFF & c3][0xFF & l3].putSymbol4x8(rans3, ptr); + rans2 = syms[0xFF & c2][0xFF & l2].putSymbol4x8(rans2, ptr); + rans1 = syms[0xFF & c1][0xFF & l1].putSymbol4x8(rans1, ptr); + rans0 = syms[0xFF & c0][0xFF & l0].putSymbol4x8(rans0, ptr); + + l0 = c0; + l1 = c1; + l2 = c2; + l3 = c3; + } + + rans3 = syms[0][0xFF & l3].putSymbol4x8(rans3, ptr); + rans2 = syms[0][0xFF & l2].putSymbol4x8(rans2, ptr); + rans1 = syms[0][0xFF & l1].putSymbol4x8(rans1, ptr); + rans0 = syms[0][0xFF & l0].putSymbol4x8(rans0, ptr); + + ptr.order(ByteOrder.BIG_ENDIAN); + ptr.putInt((int) rans3); + ptr.putInt((int) rans2); + ptr.putInt((int) rans1); + ptr.putInt((int) rans0); + ptr.flip(); + final int compressedBlobSize = ptr.limit(); + Utils.reverse(ptr); + /* + * Depletion of the in buffer cannot be confirmed because of the get(int + * position) method use during encoding, hence enforcing: + */ + inBuffer.position(inBuffer.limit()); + + // write the prefix at the beginning of the output buffer + 
writeCompressionPrefix(RANSParams.ORDER.ONE, outBuffer, inSize, frequencyTableSize, compressedBlobSize); + return outBuffer; + } + + private static void writeCompressionPrefix( + final RANSParams.ORDER order, + final ByteBuffer outBuffer, + final int inSize, + final int frequencyTableSize, + final int compressedBlobSize) { + ValidationUtils.validateArg(order == RANSParams.ORDER.ONE || order == RANSParams.ORDER.ZERO,"unrecognized RANS order"); + outBuffer.limit(Constants.RANS_4x8_PREFIX_BYTE_LENGTH + frequencyTableSize + compressedBlobSize); + + // go back to the beginning of the stream and write the prefix values + // write the (ORDER as a single byte at offset 0) + outBuffer.put(0, (byte) (order == RANSParams.ORDER.ZERO ? 0 : 1)); + // move past the ORDER and write the compressed size + outBuffer.putInt(Constants.RANS_4x8_ORDER_BYTE_LENGTH, frequencyTableSize + compressedBlobSize); + // move past the compressed size and write the uncompressed size + outBuffer.putInt(Constants.RANS_4x8_ORDER_BYTE_LENGTH + Constants.RANS_4x8_COMPRESSED_BYTE_LENGTH, inSize); + outBuffer.rewind(); + } + + private static int[] calcFrequenciesOrder0(final ByteBuffer inBuffer) { + // TODO: remove duplicate code -use Utils.normalise here + final int T = inBuffer.remaining(); + + // Compute statistics + // T = total of true counts = inBuffer size + // F = scaled integer frequencies + // M = sum(fs) + final int[] F = new int[Constants.NUMBER_OF_SYMBOLS]; + for (int i = 0; i < T; i++) { + F[0xFF & inBuffer.get()]++; + } + + // Normalise so T == TOTFREQ + // m is the maximum frequency value + // M is the symbol that has the maximum frequency + int m = 0; + int M = 0; + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + if (m < F[j]) { + m = F[j]; + M = j; + } + } + + final long tr = ((long) Constants.TOTAL_FREQ << 31) / T + (1 << 30) / T; + int fsum = 0; + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + if (F[j] == 0) { + continue; + } + // using tr to normalize symbol 
frequencies such that their total = (1<<12) = 4096 + if ((F[j] = (int) ((F[j] * tr) >> 31)) == 0) { + // make sure that a non-zero symbol frequency is not incorrectly set to 0. + // Change it to 1 if the calculated value is 0. + F[j] = 1; + } + fsum += F[j]; + } + + // Commenting the below line as it is incrementing fsum by 1, which does not make sense + // and it also makes total normalised frequency = 4095 and not 4096. + // fsum++; + + // adjust the frequency of the symbol with maximum frequency to make sure that + // the sum of frequencies of all the symbols = 4096 + if (fsum < Constants.TOTAL_FREQ) { + F[M] += Constants.TOTAL_FREQ - fsum; + } else { + F[M] -= fsum - Constants.TOTAL_FREQ; + } + return F; + } + + private static int[][] calcFrequenciesOrder1(final ByteBuffer in) { + final int in_size = in.remaining(); + + final int[][] F = new int[Constants.NUMBER_OF_SYMBOLS][Constants.NUMBER_OF_SYMBOLS]; + final int[] T = new int[Constants.NUMBER_OF_SYMBOLS]; + int last_i = 0; + for (int i = 0; i < in_size; i++) { + int c = 0xFF & in.get(); + F[last_i][c]++; + T[last_i]++; + last_i = c; + } + F[0][0xFF & in.get((in_size >> 2))]++; + F[0][0xFF & in.get(2 * (in_size >> 2))]++; + F[0][0xFF & in.get(3 * (in_size >> 2))]++; + T[0] += 3; + + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + if (T[i] == 0) { + continue; + } + + final double p = ((double) Constants.TOTAL_FREQ) / T[i]; + int t2 = 0, m = 0, M = 0; + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + if (F[i][j] == 0) + continue; + + if (m < F[i][j]) { + m = F[i][j]; + M = j; + } + + if ((F[i][j] *= p) == 0) + F[i][j] = 1; + t2 += F[i][j]; + } + + // Commenting the below line as it is incrementing t2 by 1, which does not make sense + // and it also makes total normalised frequency = 4095 and not 4096. 
+ // t2++; + + if (t2 < Constants.TOTAL_FREQ) { + F[i][M] += Constants.TOTAL_FREQ - t2; + } else { + F[i][M] -= t2 - Constants.TOTAL_FREQ; + } + } + + return F; + } + + private static int writeFrequenciesOrder0(final ByteBuffer cp, final int[] F) { + final int start = cp.position(); + + int rle = 0; + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + if (F[j] != 0) { + // j + if (rle != 0) { + rle--; + } else { + // write the symbol if it is the first symbol or if rle = 0. + // if rle != 0, then skip writing the symbol. + cp.put((byte) j); + // We've encoded two symbol frequencies in a row. + // How many more are there? Store that count so + // we can avoid writing consecutive symbols. + // Note: maximum possible rle = 254 + // rle requires atmost 1 byte + if (rle == 0 && j != 0 && F[j - 1] != 0) { + for (rle = j + 1; rle < Constants.NUMBER_OF_SYMBOLS && F[rle] != 0; rle++) + ; + rle -= j + 1; + cp.put((byte) rle); + } + } + + // F[j] + if (F[j] < 128) { + cp.put((byte) (F[j])); + } else { + // if F[j] >127, it is written in 2 bytes + cp.put((byte) (128 | (F[j] >> 8))); + cp.put((byte) (F[j] & 0xff)); + } + } + } + + // write 0 indicating the end of frequency table + cp.put((byte) 0); + return cp.position() - start; + } + + private static int writeFrequenciesOrder1(final ByteBuffer cp, final int[][] F) { + final int start = cp.position(); + final int[] T = new int[Constants.NUMBER_OF_SYMBOLS]; + + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + T[i] += F[i][j]; + } + } + + int rle_i = 0; + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + if (T[i] == 0) { + continue; + } + + // Store frequency table + // i + if (rle_i != 0) { + rle_i--; + } else { + cp.put((byte) i); + // FIXME: could use order-0 statistics to observe which alphabet + // symbols are present and base RLE on that ordering instead. 
+ if (i != 0 && T[i - 1] != 0) { + for (rle_i = i + 1; rle_i < Constants.NUMBER_OF_SYMBOLS && T[rle_i] != 0; rle_i++) + ; + rle_i -= i + 1; + cp.put((byte) rle_i); + } + } + + final int[] F_i_ = F[i]; + int rle_j = 0; + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + if (F_i_[j] != 0) { + + // j + if (rle_j != 0) { + rle_j--; + } else { + cp.put((byte) j); + if (rle_j == 0 && j != 0 && F_i_[j - 1] != 0) { + for (rle_j = j + 1; rle_j < Constants.NUMBER_OF_SYMBOLS && F_i_[rle_j] != 0; rle_j++) + ; + rle_j -= j + 1; + cp.put((byte) rle_j); + } + } + + // F_i_[j] + if (F_i_[j] < 128) { + cp.put((byte) F_i_[j]); + } else { + cp.put((byte) (128 | (F_i_[j] >> 8))); + cp.put((byte) (F_i_[j] & 0xff)); + } + } + } + cp.put((byte) 0); + } + cp.put((byte) 0); + + return cp.position() - start; + } + +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/rans4x8/RANS4x8Params.java b/src/main/java/htsjdk/samtools/cram/compression/rans/rans4x8/RANS4x8Params.java new file mode 100644 index 0000000000..8ea6c9e855 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/rans4x8/RANS4x8Params.java @@ -0,0 +1,30 @@ +package htsjdk.samtools.cram.compression.rans.rans4x8; + +import htsjdk.samtools.cram.compression.rans.RANSParams; +import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Params; + +public class RANS4x8Params implements RANSParams { + + private final ORDER order; + + public RANS4x8Params(final ORDER order) { + this.order = order; + } + + @Override + public String toString() { + return "RANS4x8Params{" + "order=" + order + "}"; + } + + @Override + public ORDER getOrder() { + return order; + } + + public int getFormatFlags(){ + return order == ORDER.ONE ? 
+ RANSNx16Params.ORDER_FLAG_MASK : + 0; + } + +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/ransnx16/RANSNx16Decode.java b/src/main/java/htsjdk/samtools/cram/compression/rans/ransnx16/RANSNx16Decode.java new file mode 100644 index 0000000000..9cf18cae13 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/ransnx16/RANSNx16Decode.java @@ -0,0 +1,442 @@ +package htsjdk.samtools.cram.compression.rans.ransnx16; + +import htsjdk.samtools.cram.CRAMException; +import htsjdk.samtools.cram.compression.CompressionUtils; +import htsjdk.samtools.cram.compression.rans.ArithmeticDecoder; +import htsjdk.samtools.cram.compression.rans.Constants; +import htsjdk.samtools.cram.compression.rans.RANSDecode; +import htsjdk.samtools.cram.compression.rans.RANSDecodingSymbol; +import htsjdk.samtools.cram.compression.rans.Utils; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; + +public class RANSNx16Decode extends RANSDecode { + private static final ByteBuffer EMPTY_BUFFER = CompressionUtils.allocateByteBuffer(0); + private static final int FREQ_TABLE_OPTIONALLY_COMPRESSED_MASK = 0x01; + private static final int RLE_META_OPTIONALLY_COMPRESSED_MASK = 0x01; + + // This method assumes that inBuffer is already rewound. + // It uncompresses the data in the inBuffer, leaving it consumed. + // Returns a rewound ByteBuffer containing the uncompressed data. 
+ public ByteBuffer uncompress(final ByteBuffer inBuffer) { + + // For RANS decoding, the bytes are read in little endian from the input stream + inBuffer.order(ByteOrder.LITTLE_ENDIAN); + return uncompress(inBuffer, 0); + } + + private ByteBuffer uncompress(final ByteBuffer inBuffer, final int outSize) { + if (inBuffer.remaining() == 0) { + return EMPTY_BUFFER; + } + + // the first byte of compressed stream gives the formatFlags + final int formatFlags = inBuffer.get() & 0xFF; + final RANSNx16Params ransNx16Params = new RANSNx16Params(formatFlags); + + // if nosz flag is set, then uncompressed size is not recorded. + int uncompressedSize = ransNx16Params.isNosz() ? outSize : CompressionUtils.readUint7(inBuffer); + + // if stripe, then decodeStripe + if (ransNx16Params.isStripe()) { + return decodeStripe(inBuffer, uncompressedSize); + } + + // if pack, get pack metadata, which will be used later to decode packed data + int packDataLength = 0; + int numSymbols = 0; + byte[] packMappingTable = null; + if (ransNx16Params.isPack()) { + packDataLength = uncompressedSize; + numSymbols = inBuffer.get() & 0xFF; + + // if (numSymbols > 16 or numSymbols==0), raise exception + if (numSymbols <= 16 && numSymbols != 0) { + packMappingTable = new byte[numSymbols]; + for (int i = 0; i < numSymbols; i++) { + packMappingTable[i] = inBuffer.get(); + } + uncompressedSize = CompressionUtils.readUint7(inBuffer); + } else { + throw new CRAMException("Bit Packing is not permitted when number of distinct symbols is greater than 16 or equal to 0. 
" + + "Number of distinct symbols: " + numSymbols); + } + } + + // if rle, get rle metadata, which will be used later to decode rle + int uncompressedRLEOutputLength = 0; + int[] rleSymbols = null; + ByteBuffer uncompressedRLEMetaData = null; + if (ransNx16Params.isRLE()) { + rleSymbols = new int[Constants.NUMBER_OF_SYMBOLS]; + final int uncompressedRLEMetaDataLength = CompressionUtils.readUint7(inBuffer); + uncompressedRLEOutputLength = uncompressedSize; + uncompressedSize = CompressionUtils.readUint7(inBuffer); + uncompressedRLEMetaData = decodeRLEMeta(inBuffer, uncompressedRLEMetaDataLength, rleSymbols, ransNx16Params); + } + + ByteBuffer outBuffer; + + // If CAT is set then, the input is uncompressed + if (ransNx16Params.isCAT()) { + outBuffer = CompressionUtils.slice(inBuffer); + outBuffer.limit(uncompressedSize); + // While resetting the position to the end is not strictly necessary, + // it is being done for the sake of completeness and + // to meet the requirements of the tests that verify the boundary conditions. 
+ inBuffer.position(inBuffer.position()+uncompressedSize); + } else { + outBuffer = CompressionUtils.allocateByteBuffer(uncompressedSize); + + // uncompressedSize is 0 in cases where Pack flag is used + // and number of distinct symbols in the raw data is 1 + if (uncompressedSize != 0) { + switch (ransNx16Params.getOrder()) { + case ZERO: + uncompressOrder0WayN(inBuffer, outBuffer, uncompressedSize, ransNx16Params); + break; + case ONE: + uncompressOrder1WayN(inBuffer, outBuffer, ransNx16Params); + break; + default: + throw new CRAMException("Unknown rANS order: " + ransNx16Params.getOrder()); + } + } + } + + // if rle, then decodeRLE + if (ransNx16Params.isRLE()) { + outBuffer = decodeRLE(outBuffer, rleSymbols, uncompressedRLEMetaData, uncompressedRLEOutputLength); + } + + // if pack, then decodePack + if (ransNx16Params.isPack()) { + outBuffer = CompressionUtils.decodePack(outBuffer, packMappingTable, numSymbols, packDataLength); + } + return outBuffer; + } + + private void uncompressOrder0WayN( + final ByteBuffer inBuffer, + final ByteBuffer outBuffer, + final int outSize, + final RANSNx16Params ransNx16Params) { + initializeRANSDecoder(); + + // read the frequency table, get the normalised frequencies and use it to set the RANSDecodingSymbols + readFrequencyTableOrder0(inBuffer); + + // uncompress using Nway rans states + final int Nway = ransNx16Params.getNumInterleavedRANSStates(); + + // Nway parallel rans states. 
Nway = 4 or 32 + final long[] rans = new long[Nway]; + + for (int r=0; r> 2) : (outSize >> 5); + + // Number of elements that don't fall into the Nway streams + int remSize = outSize - (interleaveSize * Nway); + final int out_end = outSize - remSize; + final ArithmeticDecoder D = getD()[0]; + final RANSDecodingSymbol[] syms = getDecodingSymbols()[0]; + for (int i = 0; i < out_end; i += Nway) { + for (int r=0; r0){ + byte remainingSymbol = D.reverseLookup[Utils.RANSGetCumulativeFrequency(rans[reverseIndex], Constants.TOTAL_FREQ_SHIFT)]; + syms[0xFF & remainingSymbol].advanceSymbolNx16(rans[reverseIndex], inBuffer, Constants.TOTAL_FREQ_SHIFT); + outBuffer.put(remainingSymbol); + remSize --; + reverseIndex ++; + } + outBuffer.rewind(); + } + + private void uncompressOrder1WayN( + final ByteBuffer inBuffer, + final ByteBuffer outBuffer, + final RANSNx16Params ransNx16Params) { + + // read the first byte + final int frequencyTableFirstByte = (inBuffer.get() & 0xFF); + final boolean optionalCompressFlag = ((frequencyTableFirstByte & FREQ_TABLE_OPTIONALLY_COMPRESSED_MASK)!=0); + final ByteBuffer freqTableSource; + if (optionalCompressFlag) { + + // spec: The order-1 frequency table itself may still be quite large, + // so is optionally compressed using the order-0 rANSNx16 codec with a fixed 4-way interleaving. 
+ + // if optionalCompressFlag is true, the frequency table was compressed using RANS Nx16, N=4 Order 0 + final int uncompressedLength = CompressionUtils.readUint7(inBuffer); + final int compressedLength = CompressionUtils.readUint7(inBuffer); + byte[] compressedFreqTable = new byte[compressedLength]; + + // read compressedLength bytes into compressedFreqTable byte array + inBuffer.get(compressedFreqTable,0,compressedLength); + + // decode the compressedFreqTable to get the uncompressedFreqTable using RANS Nx16, N=4 Order 0 uncompress + freqTableSource = CompressionUtils.allocateByteBuffer(uncompressedLength); + final ByteBuffer compressedFrequencyTableBuffer = CompressionUtils.wrap(compressedFreqTable); + + // uncompress using RANSNx16 Order 0, Nway = 4 + // formatFlags = (~RANSNx16Params.ORDER_FLAG_MASK & ~RANSNx16Params.N32_FLAG_MASK) = ~(RANSNx16Params.ORDER_FLAG_MASK | RANSNx16Params.N32_FLAG_MASK) + uncompressOrder0WayN(compressedFrequencyTableBuffer, freqTableSource, uncompressedLength,new RANSNx16Params(~(RANSNx16Params.ORDER_FLAG_MASK | RANSNx16Params.N32_FLAG_MASK))); // format flags = 0 + } + else { + freqTableSource = inBuffer; + } + + // Moving initializeRANSDecoder() from the beginning of this method to this point in the code + // due to the nested call to uncompressOrder0WayN, which also invokes the initializeRANSDecoder() method. + // TODO: we should work on a more permanent solution for this issue! + initializeRANSDecoder(); + final int shift = frequencyTableFirstByte >> 4; + readFrequencyTableOrder1(freqTableSource, shift); + final int outputSize = outBuffer.remaining(); + + // Nway parallel rans states. 
Nway = 4 or 32 + final int Nway = ransNx16Params.getNumInterleavedRANSStates(); + final long[] rans = new long[Nway]; + final int[] interleaveStreamIndex = new int[Nway]; + final int[] context = new int[Nway]; + + // size of interleaved stream = outputSize / Nway + // For Nway = 4, division by 4 is the same as right shift by 2 bits + // For Nway = 32, division by 32 is the same as right shift by 5 bits + final int interleaveSize = (Nway==4) ? (outputSize >> 2): (outputSize >> 5); + + for (int r=0; r 0) { + decoder.frequencies[j] = CompressionUtils.readUint7(cp); + } + } + Utils.normaliseFrequenciesOrder0Shift(decoder.frequencies, Constants.TOTAL_FREQ_SHIFT); + + final RANSDecodingSymbol[] decodingSymbols = getDecodingSymbols()[0]; + int cumulativeFrequency = 0; + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + if(alphabet[j]>0){ + + // set RANSDecodingSymbol + decodingSymbols[j].set(cumulativeFrequency, decoder.frequencies[j]); + + // update Reverse Lookup table + Arrays.fill(decoder.reverseLookup, cumulativeFrequency, cumulativeFrequency + decoder.frequencies[j], (byte) j); + cumulativeFrequency += decoder.frequencies[j]; + } + } + } + + private void readFrequencyTableOrder1( + final ByteBuffer cp, + final int shift) { + final ArithmeticDecoder[] D = getD(); + final RANSDecodingSymbol[][] decodingSymbols = getDecodingSymbols(); + final int[] alphabet = readAlphabet(cp); + for (int i=0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + if (alphabet[i] > 0) { + int run = 0; + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + if (alphabet[j] > 0) { + if (run > 0) { + run--; + } else { + D[i].frequencies[j] = CompressionUtils.readUint7(cp); + if (D[i].frequencies[j] == 0){ + run = cp.get() & 0xFF; + } + } + } + } + + // For each symbol, normalise it's order 0 frequency table + Utils.normaliseFrequenciesOrder0Shift(D[i].frequencies,shift); + int cumulativeFreq=0; + + // set decoding symbols + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + 
decodingSymbols[i][j].set( + cumulativeFreq, + D[i].frequencies[j] + ); + /* Build reverse lookup table */ + Arrays.fill(D[i].reverseLookup, cumulativeFreq, cumulativeFreq + D[i].frequencies[j], (byte) j); + cumulativeFreq+=D[i].frequencies[j]; + } + } + } + } + + private static int[] readAlphabet(final ByteBuffer cp){ + // gets the list of alphabets whose frequency!=0 + final int[] alphabet = new int[Constants.NUMBER_OF_SYMBOLS]; + int rle = 0; + int symbol = cp.get() & 0xFF; + int lastSymbol = symbol; + do { + alphabet[symbol] = 1; + if (rle!=0) { + rle--; + symbol++; + } else { + symbol = cp.get() & 0xFF; + if (symbol == lastSymbol+1) { + rle = cp.get() & 0xFF; + } + } + lastSymbol = symbol; + } while (symbol != 0); + return alphabet; + } + + private ByteBuffer decodeRLEMeta( + final ByteBuffer inBuffer, + final int uncompressedRLEMetaDataLength, + final int[] rleSymbols, + final RANSNx16Params ransNx16Params) { + final ByteBuffer uncompressedRLEMetaData; + + // The bottom bit of uncompressedRLEMetaDataLength is a flag to indicate + // whether rle metadata is uncompressed (1) or com- pressed (0). 
+ if ((uncompressedRLEMetaDataLength & RLE_META_OPTIONALLY_COMPRESSED_MASK)!=0) { + final byte[] uncompressedRLEMetaDataArray = new byte[(uncompressedRLEMetaDataLength-1)/2]; + inBuffer.get(uncompressedRLEMetaDataArray, 0, (uncompressedRLEMetaDataLength-1)/2); + uncompressedRLEMetaData = CompressionUtils.wrap(uncompressedRLEMetaDataArray); + } else { + final int compressedRLEMetaDataLength = CompressionUtils.readUint7(inBuffer); + final byte[] compressedRLEMetaDataArray = new byte[compressedRLEMetaDataLength]; + inBuffer.get(compressedRLEMetaDataArray,0,compressedRLEMetaDataLength); + final ByteBuffer compressedRLEMetaData = CompressionUtils.wrap(compressedRLEMetaDataArray); + uncompressedRLEMetaData = CompressionUtils.allocateByteBuffer(uncompressedRLEMetaDataLength / 2); + // uncompress using Order 0 and N = Nway + uncompressOrder0WayN( + compressedRLEMetaData, + uncompressedRLEMetaData, + uncompressedRLEMetaDataLength / 2, + new RANSNx16Params(0x00 | ransNx16Params.getFormatFlags() & RANSNx16Params.N32_FLAG_MASK)); + } + + int numRLESymbols = uncompressedRLEMetaData.get() & 0xFF; + if (numRLESymbols == 0) { + numRLESymbols = Constants.NUMBER_OF_SYMBOLS; + } + for (int i = 0; i< numRLESymbols; i++) { + rleSymbols[uncompressedRLEMetaData.get() & 0xFF] = 1; + } + return uncompressedRLEMetaData; + } + + private ByteBuffer decodeRLE( + final ByteBuffer inBuffer, + final int[] rleSymbols, + final ByteBuffer uncompressedRLEMetaData, + final int uncompressedRLEOutputLength) { + final ByteBuffer rleOutBuffer = CompressionUtils.allocateByteBuffer(uncompressedRLEOutputLength); + int j = 0; + for(int i = 0; j< uncompressedRLEOutputLength; i++){ + final byte sym = inBuffer.get(i); + if (rleSymbols[sym & 0xFF]!=0){ + final int run = CompressionUtils.readUint7(uncompressedRLEMetaData); + for (int r=0; r<= run; r++){ + rleOutBuffer.put(j++, sym); + } + }else { + rleOutBuffer.put(j++, sym); + } + } + return rleOutBuffer; + } + + private ByteBuffer decodeStripe(final ByteBuffer 
inBuffer, final int outSize){ + final int numInterleaveStreams = inBuffer.get() & 0xFF; + + // read lengths of compressed interleaved streams + for ( int j=0; j j){ + uncompressedLengths[j]++; + } + + transposedData[j] = uncompress(inBuffer, uncompressedLengths[j]); + } + + // Transpose + final ByteBuffer outBuffer = CompressionUtils.allocateByteBuffer(outSize); + for (int j = 0; j { + ///////////////////////////////////////////////////////////////////////////////////////////////// + // Stripe flag is not implemented in the write implementation + ///////////////////////////////////////////////////////////////////////////////////////////////// + + private static final ByteBuffer EMPTY_BUFFER = CompressionUtils.allocateByteBuffer(0); + + // This method assumes that inBuffer is already rewound. + // It compresses the data in the inBuffer, leaving it consumed. + // Returns a rewound ByteBuffer containing the compressed data. + public ByteBuffer compress(final ByteBuffer inBuffer, final RANSNx16Params ransNx16Params) { + if (inBuffer.remaining() == 0) { + return EMPTY_BUFFER; + } + final ByteBuffer outBuffer = CompressionUtils.allocateOutputBuffer(inBuffer.remaining()); + final int formatFlags = ransNx16Params.getFormatFlags(); + outBuffer.put((byte) (formatFlags)); // one byte for formatFlags + + // NoSize + if (!ransNx16Params.isNosz()) { + // original size is not recorded + CompressionUtils.writeUint7(inBuffer.remaining(),outBuffer); + } + + ByteBuffer inputBuffer = inBuffer; + + // Stripe + // Stripe flag is not implemented in the write implementation + if (ransNx16Params.isStripe()) { + throw new CRAMException("RANSNx16 Encoding with Stripe Flag is not implemented."); + } + + // Pack + if (ransNx16Params.isPack()) { + final int[] frequencyTable = new int[Constants.NUMBER_OF_SYMBOLS]; + final int inSize = inputBuffer.remaining(); + for (int i = 0; i < inSize; i ++) { + frequencyTable[inputBuffer.get(i) & 0xFF]++; + } + int numSymbols = 0; + final int[] 
packMappingTable = new int[Constants.NUMBER_OF_SYMBOLS]; + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + if (frequencyTable[i]>0) { + packMappingTable[i] = numSymbols++; + } + } + + // skip Packing if numSymbols = 0 or numSymbols > 16 + if (numSymbols !=0 && numSymbols <= 16) { + inputBuffer = CompressionUtils.encodePack(inputBuffer, outBuffer, frequencyTable, packMappingTable, numSymbols); + } else { + // unset pack flag in the first byte of the outBuffer + outBuffer.put(0,(byte)(outBuffer.get(0) & ~RANSNx16Params.PACK_FLAG_MASK)); + } + } + + // RLE + if (ransNx16Params.isRLE()){ + inputBuffer = encodeRLE(inputBuffer, outBuffer, ransNx16Params); + } + + if (ransNx16Params.isCAT()) { + // Data is uncompressed + outBuffer.put(inputBuffer); + outBuffer.limit(outBuffer.position()); + outBuffer.rewind(); // set position to 0 + return outBuffer; + } + + // if after encoding pack and rle, the inputBuffer size < Nway, then use order 0 + if (inputBuffer.remaining() < ransNx16Params.getNumInterleavedRANSStates() && ransNx16Params.getOrder() == RANSParams.ORDER.ONE) { + + // set order flag to "0" in the first byte of the outBuffer + outBuffer.put(0,(byte)(outBuffer.get(0) & ~RANSNx16Params.ORDER_FLAG_MASK)); + if (inputBuffer.remaining() == 0){ + outBuffer.limit(outBuffer.position()); + outBuffer.rewind(); + return outBuffer; + } + compressOrder0WayN(inputBuffer, new RANSNx16Params(outBuffer.get(0)), outBuffer); + return outBuffer; + } + + switch (ransNx16Params.getOrder()) { + case ZERO: + compressOrder0WayN(inputBuffer, ransNx16Params, outBuffer); + return outBuffer; + case ONE: + compressOrder1WayN(inputBuffer, ransNx16Params, outBuffer); + return outBuffer; + default: + throw new CRAMException("Unknown rANS order: " + ransNx16Params.getOrder()); + } + } + + private void compressOrder0WayN ( + final ByteBuffer inBuffer, + final RANSNx16Params ransNx16Params, + final ByteBuffer outBuffer) { + initializeRANSEncoder(); + final int inSize = inBuffer.remaining(); 
+ int bitSize = (int) Math.ceil(Math.log(inSize) / Math.log(2)); + if (bitSize > Constants.TOTAL_FREQ_SHIFT) { + bitSize = Constants.TOTAL_FREQ_SHIFT; + } + final int prefix_size = outBuffer.position(); + final int[] F = buildFrequenciesOrder0(inBuffer); + final ByteBuffer cp = CompressionUtils.slice(outBuffer); + + // Normalize Frequencies such that sum of Frequencies = 1 << bitsize + Utils.normaliseFrequenciesOrder0(F, bitSize); + + // Write the Frequency table. Keep track of the size for later + final int frequencyTableSize = writeFrequenciesOrder0(cp, F); + + // Normalise Frequencies such that sum of Frequencies = 1 << 12 + // Since, Frequencies are already normalised to be a sum of power of 2, + // for further normalisation, calculate the bit shift that is required to scale the frequencies to (1 << bits) + if (bitSize != Constants.TOTAL_FREQ_SHIFT) { + Utils.normaliseFrequenciesOrder0Shift(F, Constants.TOTAL_FREQ_SHIFT); + } + + // using the normalised frequencies, set the RANSEncodingSymbols + buildSymsOrder0(F); + inBuffer.rewind(); + final int Nway = ransNx16Params.getNumInterleavedRANSStates(); + + // number of remaining elements = inputSize % Nway = inputSize - (interleaveSize * Nway) + // For Nway = 4, division by 4 is the same as right shift by 2 bits + // For Nway = 32, division by 32 is the same as right shift by 5 bits + final int inputSize = inBuffer.remaining(); + final int interleaveSize = (Nway == 4) ? 
(inputSize >> 2) : (inputSize >> 5); + int remainingSize = inputSize - (interleaveSize * Nway); + int reverseIndex = 1; + final long[] rans = new long[Nway]; + + // initialize rans states + for (int r=0; r0){ + + // encode remaining elements first + int remainingSymbol = 0xFF & inBuffer.get(inputSize - reverseIndex); + rans[remainingSize - 1] = ransEncodingSymbols[remainingSymbol].putSymbolNx16(rans[remainingSize - 1], ptr); + remainingSize --; + reverseIndex ++; + } + final byte[] symbol = new byte[Nway]; + for (int i = (interleaveSize * Nway); i > 0; i -= Nway) { + for (int r = Nway - 1; r >= 0; r--){ + + // encode using Nway parallel rans states. Nway = 4 or 32 + symbol[r] = inBuffer.get(i - (Nway - r)); + rans[r] = ransEncodingSymbols[0xFF & symbol[r]].putSymbolNx16(rans[r], ptr); + } + } + + ptr.order(ByteOrder.BIG_ENDIAN); + for (int i=Nway-1; i>=0; i--){ + ptr.putInt((int) rans[i]); + } + ptr.position(); + ptr.flip(); + final int compressedDataSize = ptr.limit(); + + // since the data is encoded in reverse order, + // reverse the compressed bytes, so that it is in correct order when uncompressed. 
+ Utils.reverse(ptr); + inBuffer.position(inBuffer.limit()); + outBuffer.rewind(); // set position to 0 + outBuffer.limit(prefix_size + frequencyTableSize + compressedDataSize); + } + + private void compressOrder1WayN ( + final ByteBuffer inBuffer, + final RANSNx16Params ransNx16Params, + final ByteBuffer outBuffer) { + final int[][] frequencies = buildFrequenciesOrder1(inBuffer, ransNx16Params.getNumInterleavedRANSStates()); + + // normalise frequencies with a variable shift calculated + // using the minimum bit size that is needed to represent a frequency context array + Utils.normaliseFrequenciesOrder1(frequencies, Constants.TOTAL_FREQ_SHIFT); + final int prefix_size = outBuffer.position(); + + ByteBuffer frequencyTable = CompressionUtils.allocateOutputBuffer(1); + final ByteBuffer compressedFrequencyTable = CompressionUtils.allocateOutputBuffer(1); + + // uncompressed frequency table + final int uncompressedFrequencyTableSize = writeFrequenciesOrder1(frequencyTable,frequencies); + frequencyTable.limit(uncompressedFrequencyTableSize); + frequencyTable.rewind(); + + // Compress using RANSNx16 Order 0, Nway = 4. + // formatFlags = (~RANSNx16Params.ORDER_FLAG_MASK & ~RANSNx16Params.N32_FLAG_MASK) = ~(RANSNx16Params.ORDER_FLAG_MASK | RANSNx16Params.N32_FLAG_MASK) + compressOrder0WayN(frequencyTable, new RANSNx16Params(~(RANSNx16Params.ORDER_FLAG_MASK | RANSNx16Params.N32_FLAG_MASK)), compressedFrequencyTable); + frequencyTable.rewind(); + + // Moving initializeRANSEncoder() from the beginning of this method to this point in the code + // due to the nested call to compressOrder0WayN, which also invokes the initializeRANSEncoder() method. + // TODO: we should work on a more permanent solution for this issue! 
+ initializeRANSEncoder(); + final int compressedFrequencyTableSize = compressedFrequencyTable.limit(); + final ByteBuffer cp = CompressionUtils.slice(outBuffer); + + // spec: The order-1 frequency table itself may still be quite large, + // so is optionally compressed using the order-0 rANSNx16 codec with a fixed 4-way interleaving. + if (compressedFrequencyTableSize < uncompressedFrequencyTableSize) { + + // first byte + cp.put((byte) (1 | Constants.TOTAL_FREQ_SHIFT << 4 )); + CompressionUtils.writeUint7(uncompressedFrequencyTableSize,cp); + CompressionUtils.writeUint7(compressedFrequencyTableSize,cp); + + // write bytes from compressedFrequencyTable to cp + int i=0; + while (i> 2: inputSize >> 5; + final int[] interleaveStreamIndex = new int[Nway]; + final byte[] symbol = new byte[Nway]; + for (int r=0; r= 0) && (r!= Nway-1)){ + symbol[r] = inBuffer.get(interleaveStreamIndex[r] + 1); + } + if ( r == Nway-1 ){ + symbol[r] = inBuffer.get(inputSize - 1); + } + } + + // Slicing is needed for buffer reversing later. 
+ final ByteBuffer ptr = CompressionUtils.slice(cp); + final RANSEncodingSymbol[][] ransEncodingSymbols = getEncodingSymbols(); + final byte[] context = new byte[Nway]; + + // deal with the reminder + for ( + interleaveStreamIndex[Nway - 1] = inputSize - 2; + interleaveStreamIndex[Nway - 1] > Nway * interleaveSize - 2 && interleaveStreamIndex[Nway - 1] >= 0; + interleaveStreamIndex[Nway - 1]-- ) { + context[Nway - 1] = inBuffer.get(interleaveStreamIndex[Nway - 1]); + rans[Nway - 1] = ransEncodingSymbols[0xFF & context[Nway - 1]][0xFF & symbol[Nway - 1]].putSymbolNx16(rans[Nway - 1], ptr); + symbol[Nway - 1] = context[Nway - 1]; + } + + while (interleaveStreamIndex[0] >= 0) { + for (int r=0; r=0; r-- ){ + ptr.putInt((int) rans[r]); + } + + ptr.flip(); + final int compressedBlobSize = ptr.limit(); + Utils.reverse(ptr); + + /* + * Depletion of the in buffer cannot be confirmed because of the get(int + * position) method use during encoding, hence enforcing: + */ + inBuffer.position(inBuffer.limit()); + outBuffer.rewind(); + outBuffer.limit(prefix_size + frequencyTableSize + compressedBlobSize); + } + + private static int[] buildFrequenciesOrder0(final ByteBuffer inBuffer) { + // Returns an array of raw symbol frequencies + final int inSize = inBuffer.remaining(); + final int[] F = new int[Constants.NUMBER_OF_SYMBOLS]; + for (int i = 0; i < inSize; i++) { + F[0xFF & inBuffer.get()]++; + } + return F; + } + + private static int[][] buildFrequenciesOrder1(final ByteBuffer inBuffer, final int Nway) { + // Returns an array of raw symbol frequencies + final int inputSize = inBuffer.remaining(); + + // context is stored in frequency[Constants.NUMBER_OF_SYMBOLS] array + final int[][] frequency = new int[Constants.NUMBER_OF_SYMBOLS+1][Constants.NUMBER_OF_SYMBOLS]; + + // ‘\0’ is the initial context + byte contextSymbol = 0; + for (int i = 0; i < inputSize; i++) { + + // update the context array + frequency[Constants.NUMBER_OF_SYMBOLS][0xFF & contextSymbol]++; + final byte 
srcSymbol = inBuffer.get(i); + frequency[0xFF & contextSymbol][0xFF & srcSymbol ]++; + contextSymbol = srcSymbol; + } + frequency[Constants.NUMBER_OF_SYMBOLS][0xFF & contextSymbol]++; + + // set ‘\0’ as context for the first byte in the N interleaved streams. + // the first byte of the first interleaved stream is already accounted for. + for (int n = 1; n < Nway; n++){ + // For Nway = 4, division by 4 is the same as right shift by 2 bits + // For Nway = 32, division by 32 is the same as right shift by 5 bits + final int symbol = Nway == 4 ? (0xFF & inBuffer.get((n*(inputSize >> 2)))) : (0xFF & inBuffer.get((n*(inputSize >> 5)))); + frequency[0][symbol]++; + } + frequency[Constants.NUMBER_OF_SYMBOLS][0] += Nway-1; + return frequency; + } + + private static int writeFrequenciesOrder0(final ByteBuffer cp, final int[] F) { + // Order 0 frequencies store the complete alphabet of observed + // symbols using run length encoding, followed by a table of frequencies + // for each symbol in the alphabet. + final int start = cp.position(); + + // write the alphabet first and then their frequencies + writeAlphabet(cp,F); + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + if (F[j] != 0) { + if (F[j] < 128) { + cp.put((byte) (F[j] & 0x7f)); + } else { + + // if F[j] >127, it is written in 2 bytes + // right shift by 7 and get the most Significant Bits. 
+ // Set the Most Significant Bit of the first byte to 1 indicating that the frequency comprises of 2 bytes + cp.put((byte) (128 | (F[j] >> 7))); + cp.put((byte) (F[j] & 0x7f)); //Least Significant 7 Bits + } + } + } + return cp.position() - start; + } + + private static int writeFrequenciesOrder1(final ByteBuffer cp, final int[][] F) { + final int start = cp.position(); + + // writeAlphabet uses rle to write all the symbols whose frequency!=0 + writeAlphabet(cp,F[Constants.NUMBER_OF_SYMBOLS]); + + for (int i=0; i 0) { + run--; + } else { + CompressionUtils.writeUint7(F[i][j],cp); + if (F[i][j] == 0) { + // Count how many more zero-freqs we have + for (int k = j+1; k < Constants.NUMBER_OF_SYMBOLS; k++) { + if (F[Constants.NUMBER_OF_SYMBOLS][k] == 0) { + continue; + } + if (F[i][k] == 0) { + run++; + } else { + break; + } + } + cp.put((byte) run); + } + } + } + } + return cp.position() - start; + } + + private static void writeAlphabet(final ByteBuffer cp, final int[] F) { + // Uses Run Length Encoding to write all the symbols whose frequency!=0 + int rle = 0; + for (int j = 0; j < Constants.NUMBER_OF_SYMBOLS; j++) { + if (F[j] != 0) { + if (rle != 0) { + rle--; + } else { + + // write the symbol if it is the first symbol or if rle = 0. + // if rle != 0, then skip writing the symbol. + cp.put((byte) j); + + // We've encoded two symbol frequencies in a row. + // How many more are there? Store that count so + // we can avoid writing consecutive symbols. 
+ // Note: maximum possible rle = 254 + // rle requires atmost 1 byte + if (rle == 0 && j != 0 && F[j - 1] != 0) { + for (rle = j + 1; rle < Constants.NUMBER_OF_SYMBOLS && F[rle] != 0; rle++); + rle -= j + 1; + cp.put((byte) rle); + } + } + } + } + + // write 0 indicating the end of alphabet + cp.put((byte) 0); + } + + private ByteBuffer encodeRLE(final ByteBuffer inBuffer, final ByteBuffer outBuffer, final RANSNx16Params ransNx16Params){ + + // Find the symbols that benefit from RLE, i.e, the symbols that occur more than 2 times in succession. + // spec: For symbols that occur many times in succession, we can replace them with a single symbol and a count. + final int[] runCounts = new int[Constants.NUMBER_OF_SYMBOLS]; + int inputSize = inBuffer.remaining(); + + int lastSymbol = -1; + for (int i = 0; i < inputSize; i++) { + final int currentSymbol = inBuffer.get(i)&0xFF; + runCounts[currentSymbol] += (currentSymbol==lastSymbol ? 1:-1); + lastSymbol = currentSymbol; + } + + // numRLESymbols is the number of symbols that are run length encoded + int numRLESymbols = 0; + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + if (runCounts[i]>0) { + numRLESymbols++; + } + } + + if (numRLESymbols==0) { + // Format cannot cope with zero RLE symbols, so pick one! + numRLESymbols = 1; + runCounts[0] = 1; + } + + // create rleMetaData buffer to store rle metadata. + // This buffer will be compressed using compressOrder0WayN towards the end of this method + // TODO: How did we come up with this calculation for Buffer size? 
numRLESymbols+1+inputSize + final ByteBuffer rleMetaData = CompressionUtils.allocateByteBuffer(numRLESymbols+1+inputSize); // rleMetaData + + // write number of symbols that are run length encoded + rleMetaData.put((byte) numRLESymbols); + + for (int i=0; i0){ + // write the symbols that are run length encoded + rleMetaData.put((byte) i); + } + + } + + // Apply RLE + // encodedBuffer -> input src data without repetition + final ByteBuffer encodedBuffer = CompressionUtils.allocateByteBuffer(inputSize); // rleInBuffer + int encodedBufferIdx = 0; // rleInBufferIndex + + for (int i = 0; i < inputSize; i++) { + encodedBuffer.put(encodedBufferIdx++,inBuffer.get(i)); + if (runCounts[inBuffer.get(i)&0xFF]>0) { + lastSymbol = inBuffer.get(i) & 0xFF; + int run = 0; + + // calculate the run value for current symbol + while (i+run+1 < inputSize && (inBuffer.get(i+run+1)& 0xFF)==lastSymbol) { + run++; + } + + // write the run value to metadata + CompressionUtils.writeUint7(run, rleMetaData); + + // go to the next element that is not equal to its previous element + i += run; + } + } + + encodedBuffer.limit(encodedBufferIdx); + // limit and rewind + rleMetaData.limit(rleMetaData.position()); + rleMetaData.rewind(); + + // compress the rleMetaData Buffer + final ByteBuffer compressedRleMetaData = CompressionUtils.allocateOutputBuffer(rleMetaData.remaining()); + + // compress using Order 0 and N = Nway + compressOrder0WayN(rleMetaData, new RANSNx16Params(0x00 | ransNx16Params.getFormatFlags() & RANSNx16Params.N32_FLAG_MASK),compressedRleMetaData); + + // write to compressedRleMetaData to outBuffer + CompressionUtils.writeUint7(rleMetaData.limit()*2, outBuffer); + CompressionUtils.writeUint7(encodedBufferIdx, outBuffer); + CompressionUtils.writeUint7(compressedRleMetaData.limit(),outBuffer); + + outBuffer.put(compressedRleMetaData); + + /* + * Depletion of the inBuffer cannot be confirmed because of the get(int + * position) method use during encoding, hence enforcing: + */ + 
inBuffer.position(inBuffer.limit()); + return encodedBuffer; + } + +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/compression/rans/ransnx16/RANSNx16Params.java b/src/main/java/htsjdk/samtools/cram/compression/rans/ransnx16/RANSNx16Params.java new file mode 100644 index 0000000000..93bd529f27 --- /dev/null +++ b/src/main/java/htsjdk/samtools/cram/compression/rans/ransnx16/RANSNx16Params.java @@ -0,0 +1,73 @@ +package htsjdk.samtools.cram.compression.rans.ransnx16; + +import htsjdk.samtools.cram.compression.rans.RANSParams; + +public class RANSNx16Params implements RANSParams { + + // RANS Nx16 Bit Flags + public static final int ORDER_FLAG_MASK = 0x01; + public static final int N32_FLAG_MASK = 0x04; + public static final int STRIPE_FLAG_MASK = 0x08; + public static final int NOSZ_FLAG_MASK = 0x10; + public static final int CAT_FLAG_MASK = 0x20; + public static final int RLE_FLAG_MASK = 0x40; + public static final int PACK_FLAG_MASK = 0x80; + + // format is the first byte of the compressed data stream, + // which consists of all the bit-flags detailing the type of transformations + // and entropy encoders to be combined + private int formatFlags; + + private static final int FORMAT_FLAG_MASK = 0xFF; + + public RANSNx16Params(final int formatFlags) { + this.formatFlags = formatFlags; + } + + @Override + public String toString() { + return "RANSNx16Params{" + "formatFlags=" + formatFlags + "}"; + } + + @Override + public ORDER getOrder() { + // Rans Order ZERO or ONE encoding + return ORDER.fromInt(formatFlags & ORDER_FLAG_MASK); //convert into order type + } + + public int getFormatFlags(){ + // first byte of the encoded stream + return formatFlags & FORMAT_FLAG_MASK; + } + + public int getNumInterleavedRANSStates(){ + // Interleave N = 32 rANS states (else N = 4) + return ((formatFlags & N32_FLAG_MASK) == 0) ? 
4 : 32; + } + + public boolean isStripe(){ + // multiway interleaving of byte streams + return ((formatFlags & STRIPE_FLAG_MASK)!=0); + } + + public boolean isNosz(){ + // original size is not recorded (for use by Stripe) + return ((formatFlags & NOSZ_FLAG_MASK)!=0); + } + + public boolean isCAT(){ + // Data is uncompressed + return ((formatFlags & CAT_FLAG_MASK)!=0); + } + + public boolean isRLE(){ + // Run length encoding, with runs and literals encoded separately + return ((formatFlags & RLE_FLAG_MASK)!=0); + } + + public boolean isPack(){ + // Pack 2, 4, 8 or infinite symbols per byte + return ((formatFlags & PACK_FLAG_MASK)!=0); + } + +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/structure/CompressionHeaderEncodingMap.java b/src/main/java/htsjdk/samtools/cram/structure/CompressionHeaderEncodingMap.java index 638089c5e0..123c361132 100644 --- a/src/main/java/htsjdk/samtools/cram/structure/CompressionHeaderEncodingMap.java +++ b/src/main/java/htsjdk/samtools/cram/structure/CompressionHeaderEncodingMap.java @@ -26,7 +26,7 @@ import htsjdk.samtools.cram.CRAMException; import htsjdk.samtools.cram.compression.ExternalCompressor; -import htsjdk.samtools.cram.compression.rans.RANS; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Params; import htsjdk.samtools.cram.encoding.CRAMEncoding; import htsjdk.samtools.cram.encoding.external.ByteArrayStopEncoding; import htsjdk.samtools.cram.encoding.external.ExternalByteEncoding; @@ -38,10 +38,20 @@ import htsjdk.samtools.cram.structure.block.BlockCompressionMethod; import htsjdk.utils.ValidationUtils; -import java.io.*; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.nio.ByteBuffer; -import java.util.*; import htsjdk.samtools.util.Log; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import 
java.util.Map; +import java.util.Set; +import java.util.TreeMap; /** * Maintains a map of DataSeries to EncodingDescriptor, and a second map that contains the compressor to use @@ -278,12 +288,12 @@ public ExternalCompressor getBestExternalCompressor(final byte[] data, final CRA final ExternalCompressor rans0 = compressorCache.getCompressorForMethod( BlockCompressionMethod.RANS, - RANS.ORDER.ZERO.ordinal()); + RANS4x8Params.ORDER.ZERO.ordinal()); final int rans0Len = rans0.compress(data).length; final ExternalCompressor rans1 = compressorCache.getCompressorForMethod( BlockCompressionMethod.RANS, - RANS.ORDER.ONE.ordinal()); + RANS4x8Params.ORDER.ONE.ordinal()); final int rans1Len = rans1.compress(data).length; // find the best of general purpose codecs: @@ -387,14 +397,14 @@ private void putExternalGzipEncoding(final CRAMEncodingStrategy encodingStrategy private void putExternalRansOrderOneEncoding(final DataSeries dataSeries) { putExternalEncoding( dataSeries, - compressorCache.getCompressorForMethod(BlockCompressionMethod.RANS, RANS.ORDER.ONE.ordinal())); + compressorCache.getCompressorForMethod(BlockCompressionMethod.RANS, RANS4x8Params.ORDER.ONE.ordinal())); } // add an external encoding appropriate for the dataSeries value type, with a RANS order 0 compressor private void putExternalRansOrderZeroEncoding(final DataSeries dataSeries) { putExternalEncoding( dataSeries, - compressorCache.getCompressorForMethod(BlockCompressionMethod.RANS, RANS.ORDER.ZERO.ordinal())); + compressorCache.getCompressorForMethod(BlockCompressionMethod.RANS, RANS4x8Params.ORDER.ZERO.ordinal())); } @Override @@ -415,4 +425,4 @@ public int hashCode() { return result; } -} +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/structure/CompressorCache.java b/src/main/java/htsjdk/samtools/cram/structure/CompressorCache.java index 7021664be3..81b4f98199 100644 --- a/src/main/java/htsjdk/samtools/cram/structure/CompressorCache.java +++ 
b/src/main/java/htsjdk/samtools/cram/structure/CompressorCache.java @@ -24,8 +24,11 @@ */ package htsjdk.samtools.cram.structure; -import htsjdk.samtools.cram.compression.*; -import htsjdk.samtools.cram.compression.rans.RANS; +import htsjdk.samtools.cram.compression.ExternalCompressor; +import htsjdk.samtools.cram.compression.RANSExternalCompressor; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Decode; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Encode; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Params; import htsjdk.samtools.cram.structure.block.BlockCompressionMethod; import htsjdk.samtools.util.Tuple; import htsjdk.utils.ValidationUtils; @@ -40,7 +43,8 @@ public class CompressorCache { private final String argErrorMessage = "Invalid compression arg (%d) requested for CRAM %s compressor"; private final HashMap, ExternalCompressor> compressorCache = new HashMap<>(); - private RANS sharedRANS; + private RANS4x8Encode sharedRANSEncode; + private RANS4x8Decode sharedRANSDecode; /** * Return a compressor if its in our cache, otherwise spin one up and cache it and return it. @@ -67,22 +71,26 @@ public ExternalCompressor getCompressorForMethod( // for efficiency, we want to share the same underlying RANS object with both order-0 and // order-1 ExternalCompressors final int ransArg = compressorSpecificArg == ExternalCompressor.NO_COMPRESSION_ARG ? 
- RANS.ORDER.ZERO.ordinal() : + RANS4x8Params.ORDER.ZERO.ordinal() : compressorSpecificArg; final Tuple compressorTuple = new Tuple<>( BlockCompressionMethod.RANS, ransArg); if (!compressorCache.containsKey(compressorTuple)) { - if (sharedRANS == null) { - sharedRANS = new RANS(); + if (sharedRANSEncode == null) { + sharedRANSEncode = new RANS4x8Encode(); + } + if (sharedRANSDecode == null) { + sharedRANSDecode = new RANS4x8Decode(); } compressorCache.put( new Tuple(BlockCompressionMethod.RANS, ransArg), - new RANSExternalCompressor(ransArg, sharedRANS) + new RANSExternalCompressor(ransArg, sharedRANSEncode, sharedRANSDecode) ); } return getCachedCompressorForMethod(compressorTuple.a, compressorTuple.b); - + case RANGE: + return getCachedCompressorForMethod(compressionMethod, compressorSpecificArg); default: throw new IllegalArgumentException(String.format("Unknown compression method %s", compressionMethod)); } @@ -97,4 +105,4 @@ private ExternalCompressor getCachedCompressorForMethod(final BlockCompressionMe ); } -} +} \ No newline at end of file diff --git a/src/main/java/htsjdk/samtools/cram/structure/block/BlockCompressionMethod.java b/src/main/java/htsjdk/samtools/cram/structure/block/BlockCompressionMethod.java index f37b82e463..d4b1c8aa7a 100644 --- a/src/main/java/htsjdk/samtools/cram/structure/block/BlockCompressionMethod.java +++ b/src/main/java/htsjdk/samtools/cram/structure/block/BlockCompressionMethod.java @@ -32,7 +32,8 @@ public enum BlockCompressionMethod { GZIP(1), BZIP2(2), LZMA(3), - RANS(4); + RANS(4), + RANGE(5); private final int methodId; @@ -65,4 +66,4 @@ public static BlockCompressionMethod byId(final int id) { private static final Map ID_MAP = Collections.unmodifiableMap(Stream.of(BlockCompressionMethod.values()) .collect(Collectors.toMap(BlockCompressionMethod::getMethodId, Function.identity()))); -} +} \ No newline at end of file diff --git a/src/test/java/htsjdk/samtools/cram/CRAMInteropTestUtils.java 
b/src/test/java/htsjdk/samtools/cram/CRAMInteropTestUtils.java new file mode 100644 index 0000000000..eaee961bf5 --- /dev/null +++ b/src/test/java/htsjdk/samtools/cram/CRAMInteropTestUtils.java @@ -0,0 +1,97 @@ +package htsjdk.samtools.cram; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; + +import htsjdk.utils.SamtoolsTestUtils; + +/** + * Interop test data is kept in a separate repository, currently at https://github.com/samtools/htscodecs + * so it can be shared across htslib/samtools/htsjdk. + */ +public class CRAMInteropTestUtils { + public static final String INTEROP_TEST_FILES_PATH = SamtoolsTestUtils.getCRAMInteropData(); + + /** + * @return true if interop test data is available, otherwise false + */ + public static boolean isInteropTestDataAvailable() { + final Path testDataPath = getInteropTestDataLocation(); + return Files.exists(testDataPath); + } + + /** + * @return the name and location of the local interop test data as specified by the + * variable INTEROP_TEST_FILES_PATH + */ + public static Path getInteropTestDataLocation() { + return Paths.get(INTEROP_TEST_FILES_PATH); + } + + // the input files have embedded newlines that the tests remove before round-tripping... + protected static final byte[] filterEmbeddedNewlines(final byte[] rawBytes) throws IOException { + // 1. filters new lines if any. + // 2. "q40+dir" file has an extra column delimited by tab. This column provides READ1 vs READ2 flag. + // This file is also new-line separated. The extra column, '\t' and '\n' are filtered. 
+ try (final ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + int skip = 0; + for (final byte b : rawBytes) { + if (b == '\t'){ + skip = 1; + } + if (b == '\n') { + skip = 0; + } + if (skip == 0 && b !='\n') { + baos.write(b); + } + } + return baos.toByteArray(); + } + } + + // return a list of all encoded test data files in the htscodecs/tests/dat/ directory + protected static List getInteropCompressedFilePaths(final String compressedDir) throws IOException { + final List paths = new ArrayList<>(); + Files.newDirectoryStream( + CRAMInteropTestUtils.getInteropTestDataLocation().resolve("dat/"+compressedDir), + path -> Files.isRegularFile(path)) + .forEach(path -> paths.add(path)); + return paths; + } + + // Given a compressed test file path, return the corresponding uncompressed file path + protected static final Path getUnCompressedFilePath(final Path compressedInteropPath) { + final String uncompressedFileName = getUncompressedFileName(compressedInteropPath.getFileName().toString()); + // Example compressedInteropPath: ../dat/r4x8/q4.1 => unCompressedFilePath: ../dat/q4 + return compressedInteropPath.getParent().getParent().resolve(uncompressedFileName); + } + + private static final String getUncompressedFileName(final String compressedFileName) { + // Returns original filename from compressed file name + final int lastDotIndex = compressedFileName.lastIndexOf("."); + if (lastDotIndex >= 0) { + return compressedFileName.substring(0, lastDotIndex); + } else { + throw new CRAMException("The format of the compressed File Name is not as expected. " + + "The name of the compressed file should contain a period followed by a number that" + + "indicates the order of compression. 
Actual compressed file name = "+ compressedFileName); + } + } + + // return a list of all raw test files in the htscodecs/tests/dat directory + protected static final List getInteropRawTestFiles() throws IOException { + final List paths = new ArrayList<>(); + Files.newDirectoryStream( + CRAMInteropTestUtils.getInteropTestDataLocation().resolve("dat"), + path -> (Files.isRegularFile(path)) && !Files.isHidden(path)) + .forEach(path -> paths.add(path)); + return paths; + } + +} \ No newline at end of file diff --git a/src/test/java/htsjdk/samtools/cram/NameTokenizationInteropTest.java b/src/test/java/htsjdk/samtools/cram/NameTokenizationInteropTest.java new file mode 100644 index 0000000000..8a4aa0e22b --- /dev/null +++ b/src/test/java/htsjdk/samtools/cram/NameTokenizationInteropTest.java @@ -0,0 +1,134 @@ +package htsjdk.samtools.cram; + +import htsjdk.HtsjdkTest; +import htsjdk.samtools.cram.compression.nametokenisation.NameTokenisationDecode; +import htsjdk.samtools.cram.compression.nametokenisation.NameTokenisationEncode; +import org.apache.commons.compress.utils.IOUtils; +import org.testng.Assert; +import org.testng.SkipException; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; + +public class NameTokenizationInteropTest extends HtsjdkTest { + public static final String COMPRESSED_TOK_DIR = "tok3"; + + @DataProvider(name = "allNameTokenizationFiles") + public Object[][] getAllNameTokenizationCodecsForRoundTrip() throws IOException { + + // params: + // compressed testfile path, uncompressed testfile path, NameTokenization encoder, NameTokenization decoder + final List testCases = new ArrayList<>(); + for (Path path : 
getInteropNameTokenizationCompressedFiles()) { + Object[] objects = new Object[]{ + path, + getNameTokenizationUnCompressedFilePath(path), + new NameTokenisationEncode(), + new NameTokenisationDecode() + }; + testCases.add(objects); + } + return testCases.toArray(new Object[][]{}); + } + + @Test(description = "Test if CRAM Interop Test Data is available") + public void testGetHTSCodecsCorpus() { + if (!CRAMInteropTestUtils.isInteropTestDataAvailable()) { + throw new SkipException(String.format("CRAM Interop Test Data is not available at %s", + CRAMInteropTestUtils.INTEROP_TEST_FILES_PATH)); + } + } + + @Test ( + dependsOnMethods = "testGetHTSCodecsCorpus", + dataProvider = "allNameTokenizationFiles", + description = "Roundtrip using htsjdk NameTokenization Codec. Compare the output with the original file" ) + public void testRangeRoundTrip( + final Path precompressedFilePath, + final Path uncompressedFilePath, + final NameTokenisationEncode nameTokenisationEncode, + final NameTokenisationDecode nameTokenisationDecode) throws IOException { + try(final InputStream preCompressedInteropStream = Files.newInputStream(precompressedFilePath); + final InputStream unCompressedInteropStream = Files.newInputStream(uncompressedFilePath)){ + final ByteBuffer preCompressedInteropBytes = ByteBuffer.wrap(IOUtils.toByteArray(preCompressedInteropStream)); + final ByteBuffer unCompressedInteropBytes = ByteBuffer.wrap(IOUtils.toByteArray(unCompressedInteropStream)); + ByteBuffer compressedHtsjdkBytes = nameTokenisationEncode.compress(unCompressedInteropBytes); + String decompressedHtsjdkString = nameTokenisationDecode.uncompress(compressedHtsjdkBytes); + ByteBuffer decompressedHtsjdkBytes = StandardCharsets.UTF_8.encode(decompressedHtsjdkString); + unCompressedInteropBytes.rewind(); + Assert.assertEquals(decompressedHtsjdkBytes, unCompressedInteropBytes); + } catch (final NoSuchFileException ex){ + throw new SkipException("Skipping testRangeRoundTrip as either the input precompressed 
file " + + "or the uncompressed file is missing.", ex); + } + } + + + + @Test ( + dependsOnMethods = "testGetHTSCodecsCorpus", + dataProvider = "allNameTokenizationFiles", + description = "Compress the original file using htsjdk NameTokenization Codec and compare it with the existing compressed file. " + + "Uncompress the existing compressed file using htsjdk NameTokenization Codec and compare it with the original file.") + public void testtNameTokenizationPreCompressed( + final Path compressedFilePath, + final Path uncompressedFilePath, + final NameTokenisationEncode unsusednameTokenisationEncode, + final NameTokenisationDecode nameTokenisationDecode) throws IOException { + try(final InputStream preCompressedInteropStream = Files.newInputStream(compressedFilePath); + final InputStream unCompressedInteropStream = Files.newInputStream(uncompressedFilePath)){ + final ByteBuffer preCompressedInteropBytes = ByteBuffer.wrap(IOUtils.toByteArray(preCompressedInteropStream)); + final ByteBuffer unCompressedInteropBytes = ByteBuffer.wrap(IOUtils.toByteArray(unCompressedInteropStream)); + + // Use htsjdk to uncompress the precompressed file from htscodecs repo + final String uncompressedHtsjdkString = nameTokenisationDecode.uncompress(preCompressedInteropBytes); + ByteBuffer uncompressedHtsjdkBytes = StandardCharsets.UTF_8.encode(uncompressedHtsjdkString); + + // Compare the htsjdk uncompressed bytes with the original input file from htscodecs repo + Assert.assertEquals(uncompressedHtsjdkBytes, unCompressedInteropBytes); + } catch (final NoSuchFileException ex){ + throw new SkipException("Skipping testNameTokenizationPrecompressed as either input file " + + "or precompressed file is missing.", ex); + } + + } + + // return a list of all NameTokenization encoded test data files in the htscodecs/tests/names/tok3 directory + private List getInteropNameTokenizationCompressedFiles() throws IOException { + final List paths = new ArrayList<>(); + Files.newDirectoryStream( + 
CRAMInteropTestUtils.getInteropTestDataLocation().resolve("names/"+COMPRESSED_TOK_DIR), + path -> Files.isRegularFile(path)) + .forEach(path -> paths.add(path)); + return paths; + } + + // Given a compressed test file path, return the corresponding uncompressed file path + public static final Path getNameTokenizationUnCompressedFilePath(final Path compressedInteropPath) { + String uncompressedFileName = getUncompressedFileName(compressedInteropPath.getFileName().toString()); + // Example compressedInteropPath: ../names/tok3/01.names.1 => unCompressedFilePath: ../names/01.names + return compressedInteropPath.getParent().getParent().resolve(uncompressedFileName); + } + + public static final String getUncompressedFileName(final String compressedFileName) { + // Returns original filename from compressed file name + int lastDotIndex = compressedFileName.lastIndexOf("."); + if (lastDotIndex >= 0) { + return compressedFileName.substring(0, lastDotIndex); + } else { + throw new CRAMException("The format of the compressed File Name is not as expected. " + + "The name of the compressed file should contain a period followed by a number that" + + "indicates type of compression. 
Actual compressed file name = "+ compressedFileName); + } + } + +} \ No newline at end of file diff --git a/src/test/java/htsjdk/samtools/cram/RANSInteropTest.java b/src/test/java/htsjdk/samtools/cram/RANSInteropTest.java new file mode 100644 index 0000000000..3b5358075c --- /dev/null +++ b/src/test/java/htsjdk/samtools/cram/RANSInteropTest.java @@ -0,0 +1,219 @@ +package htsjdk.samtools.cram; + +import htsjdk.HtsjdkTest; +import htsjdk.samtools.cram.compression.CompressionUtils; +import htsjdk.samtools.cram.compression.rans.RANSDecode; +import htsjdk.samtools.cram.compression.rans.RANSEncode; +import htsjdk.samtools.cram.compression.rans.RANSParams; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Decode; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Encode; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Params; +import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Decode; +import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Encode; +import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Params; +import org.apache.commons.compress.utils.IOUtils; +import org.testng.Assert; +import org.testng.SkipException; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +/** + * RANSInteropTest tests if the htsjdk RANS4x8 and RANSNx16 implementations are interoperable + * with the htslib implementations. The test files for Interop tests are kept in a separate repository, + * currently at https://github.com/samtools/htscodecs so it can be shared across htslib/samtools/htsjdk. 
+ * + * For local development env, the Interop test files must be downloaded locally and made available at "../htscodecs/tests" + * For CI env, the Interop test files are made available from the existing samtools installation + * at "/samtools-1.14/htslib-1.14/htscodecs/tests" + */ +public class RANSInteropTest extends HtsjdkTest { + public static final String COMPRESSED_RANS4X8_DIR = "r4x8"; + public static final String COMPRESSED_RANSNX16_DIR = "r4x16"; + + // enumerates the different flag combinations + public Object[][] get4x8RoundTripTestCases() throws IOException { + + // params: + // uncompressed testfile path, + // RANS encoder, RANS decoder, RANS params + final List rans4x8ParamsOrderList = Arrays.asList( + RANSParams.ORDER.ZERO, + RANSParams.ORDER.ONE); + final List testCases = new ArrayList<>(); + CRAMInteropTestUtils.getInteropRawTestFiles() + .forEach(path -> + rans4x8ParamsOrderList.stream().map(rans4x8ParamsOrder -> new Object[]{ + path, + new RANS4x8Encode(), + new RANS4x8Decode(), + new RANS4x8Params(rans4x8ParamsOrder) + }).forEach(testCases::add)); + return testCases.toArray(new Object[][]{}); + } + + // enumerates the different flag combinations + public Object[][] getNx16RoundTripTestCases() throws IOException { + + // params: + // uncompressed testfile path, + // RANS encoder, RANS decoder, RANS params + final List ransNx16ParamsFormatFlagList = Arrays.asList( + 0x00, + RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.N32_FLAG_MASK, + RANSNx16Params.N32_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.CAT_FLAG_MASK, + RANSNx16Params.CAT_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.CAT_FLAG_MASK | RANSNx16Params.N32_FLAG_MASK, + RANSNx16Params.CAT_FLAG_MASK | RANSNx16Params.N32_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.RLE_FLAG_MASK, + RANSNx16Params.RLE_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.PACK_FLAG_MASK, + RANSNx16Params.PACK_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK, + 
RANSNx16Params.RLE_FLAG_MASK | RANSNx16Params.PACK_FLAG_MASK, + RANSNx16Params.RLE_FLAG_MASK | RANSNx16Params.PACK_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK); + final List testCases = new ArrayList<>(); + CRAMInteropTestUtils.getInteropRawTestFiles() + .forEach(path -> + ransNx16ParamsFormatFlagList.stream().map(ransNx16ParamsFormatFlag -> new Object[]{ + path, + new RANSNx16Encode(), + new RANSNx16Decode(), + new RANSNx16Params(ransNx16ParamsFormatFlag) + }).forEach(testCases::add)); + return testCases.toArray(new Object[][]{}); + } + + // uses the available compressed interop test files + public Object[][] get4x8DecodeOnlyTestCases() throws IOException { + + // params: + // compressed testfile path, uncompressed testfile path, + // RANS decoder + final List testCases = new ArrayList<>(); + for (Path path : CRAMInteropTestUtils.getInteropCompressedFilePaths(COMPRESSED_RANS4X8_DIR)) { + Object[] objects = new Object[]{ + path, + CRAMInteropTestUtils.getUnCompressedFilePath(path), + new RANS4x8Decode() + }; + testCases.add(objects); + } + return testCases.toArray(new Object[][]{}); + } + + // uses the available compressed interop test files + public Object[][] getNx16DecodeOnlyTestCases() throws IOException { + + // params: + // compressed testfile path, uncompressed testfile path, + // RANS decoder + final List testCases = new ArrayList<>(); + for (Path path : CRAMInteropTestUtils.getInteropCompressedFilePaths(COMPRESSED_RANSNX16_DIR)) { + Object[] objects = new Object[]{ + path, + CRAMInteropTestUtils.getUnCompressedFilePath(path), + new RANSNx16Decode() + }; + testCases.add(objects); + } + return testCases.toArray(new Object[][]{}); + } + + @DataProvider(name = "roundTripTestCases") + public Object[][] getRoundTripTestCases() throws IOException { + + // params: + // uncompressed testfile path, + // RANS encoder, RANS decoder, RANS params + return Stream.concat(Arrays.stream(get4x8RoundTripTestCases()), Arrays.stream(getNx16RoundTripTestCases())) + 
.toArray(Object[][]::new); + } + + @DataProvider(name = "decodeOnlyTestCases") + public Object[][] getDecodeOnlyTestCases() throws IOException { + + // params: + // compressed testfile path, uncompressed testfile path, + // RANS decoder + return Stream.concat(Arrays.stream(get4x8DecodeOnlyTestCases()), Arrays.stream(getNx16DecodeOnlyTestCases())) + .toArray(Object[][]::new); + } + + @Test(description = "Test if CRAM Interop Test Data is available") + public void testHtsCodecsCorpusIsAvailable() { + if (!CRAMInteropTestUtils.isInteropTestDataAvailable()) { + throw new SkipException(String.format("CRAM Interop Test Data is not available at %s", + CRAMInteropTestUtils.INTEROP_TEST_FILES_PATH)); + } + } + + @Test ( + dependsOnMethods = "testHtsCodecsCorpusIsAvailable", + dataProvider = "roundTripTestCases", + description = "Roundtrip using htsjdk RANS. Compare the output with the original file" ) + public void testRANSRoundTrip( + final Path uncompressedFilePath, + final RANSEncode ransEncode, + final RANSDecode ransDecode, + final RANSParams params) throws IOException { + try (final InputStream uncompressedInteropStream = Files.newInputStream(uncompressedFilePath)) { + + // preprocess the uncompressed data (to match what the htscodecs-library test harness does) + // by filtering out the embedded newlines, and then round trip through RANS and compare the + // results + final ByteBuffer uncompressedInteropBytes = CompressionUtils.wrap(CRAMInteropTestUtils.filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream))); + + // Stripe Flag is not implemented in RANSNx16 Encoder. + // The encoder throws CRAMException if Stripe Flag is used. 
+ if (params instanceof RANSNx16Params && ((RANSNx16Params) params).isStripe()) { + Assert.assertThrows(CRAMException.class, () -> ransEncode.compress(uncompressedInteropBytes, params)); + } else { + final ByteBuffer compressedHtsjdkBytes = ransEncode.compress(uncompressedInteropBytes, params); + uncompressedInteropBytes.rewind(); + Assert.assertEquals(ransDecode.uncompress(compressedHtsjdkBytes), uncompressedInteropBytes); + } + } + } + + @Test ( + dependsOnMethods = "testHtsCodecsCorpusIsAvailable", + dataProvider = "decodeOnlyTestCases", + description = "Uncompress the existing compressed file using htsjdk RANS and compare it with the original file.") + public void testDecodeOnly( + final Path compressedFilePath, + final Path uncompressedInteropPath, + final RANSDecode ransDecode) throws IOException { + try (final InputStream uncompressedInteropStream = Files.newInputStream(uncompressedInteropPath); + final InputStream preCompressedInteropStream = Files.newInputStream(compressedFilePath) + ) { + + // preprocess the uncompressed data (to match what the htscodecs-library test harness does) + // by filtering out the embedded newlines, and then round trip through RANS and compare the + // results + final ByteBuffer uncompressedInteropBytes = CompressionUtils.wrap(CRAMInteropTestUtils.filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream))); + final ByteBuffer preCompressedInteropBytes = CompressionUtils.wrap(IOUtils.toByteArray(preCompressedInteropStream)); + + // Use htsjdk to uncompress the precompressed file from htscodecs repo + final ByteBuffer uncompressedHtsjdkBytes = ransDecode.uncompress(preCompressedInteropBytes); + + // Compare the htsjdk uncompressed bytes with the original input file from htscodecs repo + Assert.assertEquals(uncompressedHtsjdkBytes, uncompressedInteropBytes); + } catch (final NoSuchFileException ex){ + throw new SkipException("Skipping testDecodeOnly as either input file " + + "or precompressed file is missing.", ex); + } 
+ } + +} \ No newline at end of file diff --git a/src/test/java/htsjdk/samtools/cram/RangeInteropTest.java b/src/test/java/htsjdk/samtools/cram/RangeInteropTest.java new file mode 100644 index 0000000000..72a88da1fc --- /dev/null +++ b/src/test/java/htsjdk/samtools/cram/RangeInteropTest.java @@ -0,0 +1,145 @@ +package htsjdk.samtools.cram; + +import htsjdk.HtsjdkTest; +import htsjdk.samtools.cram.compression.range.RangeDecode; +import htsjdk.samtools.cram.compression.range.RangeEncode; +import htsjdk.samtools.cram.compression.range.RangeParams; +import org.apache.commons.compress.utils.IOUtils; +import org.testng.Assert; +import org.testng.SkipException; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class RangeInteropTest extends HtsjdkTest { + public static final String COMPRESSED_RANGE_DIR = "arith"; + + // enumerates the different flag combinations + @DataProvider(name = "roundTripTestCases") + public Object[][] getRoundTripTestCases() throws IOException { + + // params: + // uncompressed testfile path, + // Range encoder, Range decoder, Range params + final List rangeParamsFormatFlagList = Arrays.asList( + 0x00, + RangeParams.ORDER_FLAG_MASK, + RangeParams.RLE_FLAG_MASK, + RangeParams.RLE_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, + RangeParams.CAT_FLAG_MASK, + RangeParams.CAT_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, + RangeParams.PACK_FLAG_MASK, + RangeParams.PACK_FLAG_MASK | RangeParams. 
ORDER_FLAG_MASK, + RangeParams.PACK_FLAG_MASK | RangeParams.RLE_FLAG_MASK, + RangeParams.PACK_FLAG_MASK | RangeParams.RLE_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, + RangeParams.EXT_FLAG_MASK, + RangeParams.EXT_FLAG_MASK | RangeParams.PACK_FLAG_MASK); + final List testCases = new ArrayList<>(); + CRAMInteropTestUtils.getInteropRawTestFiles() + .forEach(path -> + rangeParamsFormatFlagList.stream().map(rangeParamsFormatFlag -> new Object[]{ + path, + new RangeEncode(), + new RangeDecode(), + new RangeParams(rangeParamsFormatFlag) + }).forEach(testCases::add)); + return testCases.toArray(new Object[][]{}); + } + + // uses the available compressed interop test files + @DataProvider(name = "decodeOnlyTestCases") + public Object[][] getDecodeOnlyTestCases() throws IOException { + + // params: + // compressed testfile path, uncompressed testfile path, + // Range decoder + final List testCases = new ArrayList<>(); + for (Path path : CRAMInteropTestUtils.getInteropCompressedFilePaths(COMPRESSED_RANGE_DIR)) { + Object[] objects = new Object[]{ + path, + CRAMInteropTestUtils.getUnCompressedFilePath(path), + new RangeDecode() + }; + testCases.add(objects); + } + return testCases.toArray(new Object[][]{}); + } + + @Test(description = "Test if CRAM Interop Test Data is available") + public void testHtsCodecsCorpusIsAvailable() { + if (!CRAMInteropTestUtils.isInteropTestDataAvailable()) { + throw new SkipException(String.format("CRAM Interop Test Data is not available at %s", + CRAMInteropTestUtils.INTEROP_TEST_FILES_PATH)); + } + } + + @Test ( + dependsOnMethods = "testHtsCodecsCorpusIsAvailable", + dataProvider = "roundTripTestCases", + description = "Roundtrip using htsjdk Range Codec. 
Compare the output with the original file" ) + public void testRangeRoundTrip( + final Path uncompressedFilePath, + final RangeEncode rangeEncode, + final RangeDecode rangeDecode, + final RangeParams params) throws IOException { + try (final InputStream uncompressedInteropStream = Files.newInputStream(uncompressedFilePath)) { + + // preprocess the uncompressed data (to match what the htscodecs-library test harness does) + // by filtering out the embedded newlines, and then round trip through Range codec and compare the + // results + final ByteBuffer uncompressedInteropBytes = ByteBuffer.wrap(CRAMInteropTestUtils.filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream))); + + if (params.isStripe()) { + Assert.assertThrows(CRAMException.class, () -> rangeEncode.compress(uncompressedInteropBytes, params)); + } else { + final ByteBuffer compressedHtsjdkBytes = rangeEncode.compress(uncompressedInteropBytes, params); + uncompressedInteropBytes.rewind(); + Assert.assertEquals(rangeDecode.uncompress(compressedHtsjdkBytes), uncompressedInteropBytes); + } + } + } + + @Test ( + dependsOnMethods = "testHtsCodecsCorpusIsAvailable", + dataProvider = "decodeOnlyTestCases", + description = "Uncompress the existing compressed file using htsjdk Range codec and compare it with the original file.") + public void testDecodeOnly( + final Path compressedFilePath, + final Path uncompressedInteropPath, + final RangeDecode rangeDecode) throws IOException { + try (final InputStream uncompressedInteropStream = Files.newInputStream(uncompressedInteropPath); + final InputStream preCompressedInteropStream = Files.newInputStream(compressedFilePath) + ) { + // preprocess the uncompressed data (to match what the htscodecs-library test harness does) + // by filtering out the embedded newlines, and then round trip through Range codec + // and compare the results + + final ByteBuffer uncompressedInteropBytes; + if (uncompressedInteropPath.toString().contains("htscodecs/tests/dat/u")) { + 
uncompressedInteropBytes = ByteBuffer.wrap(IOUtils.toByteArray(uncompressedInteropStream)); + } else { + uncompressedInteropBytes = ByteBuffer.wrap(CRAMInteropTestUtils.filterEmbeddedNewlines(IOUtils.toByteArray(uncompressedInteropStream))); + } + final ByteBuffer preCompressedInteropBytes = ByteBuffer.wrap(IOUtils.toByteArray(preCompressedInteropStream)); + + // Use htsjdk to uncompress the precompressed file from htscodecs repo + final ByteBuffer uncompressedHtsjdkBytes = rangeDecode.uncompress(preCompressedInteropBytes); + + // Compare the htsjdk uncompressed bytes with the original input file from htscodecs repo + Assert.assertEquals(uncompressedHtsjdkBytes, uncompressedInteropBytes); + } catch (final NoSuchFileException ex){ + throw new SkipException("Skipping testDecodeOnly as either input file " + + "or precompressed file is missing.", ex); + } + } + +} \ No newline at end of file diff --git a/src/test/java/htsjdk/samtools/cram/build/SliceFactoryTest.java b/src/test/java/htsjdk/samtools/cram/build/SliceFactoryTest.java index cee032bd6c..72316c47c9 100644 --- a/src/test/java/htsjdk/samtools/cram/build/SliceFactoryTest.java +++ b/src/test/java/htsjdk/samtools/cram/build/SliceFactoryTest.java @@ -4,7 +4,6 @@ import htsjdk.samtools.SAMFileHeader; import htsjdk.samtools.SAMRecord; import htsjdk.samtools.cram.CRAMException; -import htsjdk.samtools.cram.compression.rans.RANS; import htsjdk.samtools.cram.ref.ReferenceContext; import htsjdk.samtools.cram.structure.CRAMEncodingStrategy; import htsjdk.samtools.cram.structure.CRAMStructureTestHelper; diff --git a/src/test/java/htsjdk/samtools/cram/compression/CompressorCacheTest.java b/src/test/java/htsjdk/samtools/cram/compression/CompressorCacheTest.java index 20e84ed22f..a4f684fb11 100644 --- a/src/test/java/htsjdk/samtools/cram/compression/CompressorCacheTest.java +++ b/src/test/java/htsjdk/samtools/cram/compression/CompressorCacheTest.java @@ -2,6 +2,7 @@ import htsjdk.HtsjdkTest; import htsjdk.samtools.Defaults; 
+import htsjdk.samtools.cram.compression.range.RangeParams; import htsjdk.samtools.cram.structure.CompressorCache; import htsjdk.samtools.cram.structure.block.BlockCompressionMethod; import org.testng.Assert; @@ -30,6 +31,19 @@ public Object[][] cachedCompressorForMethodPositiveTests() { {BlockCompressionMethod.RANS, 1, RANSExternalCompressor.class}, {BlockCompressionMethod.RANS, 0, RANSExternalCompressor.class}, {BlockCompressionMethod.RANS, ExternalCompressor.NO_COMPRESSION_ARG, RANSExternalCompressor.class}, + {BlockCompressionMethod.RANGE, 0x00, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.ORDER_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.RLE_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.RLE_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.CAT_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.CAT_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.PACK_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.PACK_FLAG_MASK | RangeParams. 
ORDER_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.PACK_FLAG_MASK | RangeParams.RLE_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.PACK_FLAG_MASK | RangeParams.RLE_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.EXT_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.EXT_FLAG_MASK | RangeParams.PACK_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, ExternalCompressor.NO_COMPRESSION_ARG, RangeExternalCompressor.class}, }; } @@ -63,4 +77,4 @@ public void testGetCompressorForMethodNegative( final int compressorSpecificArg) { compressorCache.getCompressorForMethod(method, compressorSpecificArg); } -} +} \ No newline at end of file diff --git a/src/test/java/htsjdk/samtools/cram/compression/ExternalCompressionTest.java b/src/test/java/htsjdk/samtools/cram/compression/ExternalCompressionTest.java index 84375ea84a..252a7ef8b2 100644 --- a/src/test/java/htsjdk/samtools/cram/compression/ExternalCompressionTest.java +++ b/src/test/java/htsjdk/samtools/cram/compression/ExternalCompressionTest.java @@ -2,6 +2,7 @@ import htsjdk.HtsjdkTest; import htsjdk.samtools.Defaults; +import htsjdk.samtools.cram.compression.range.RangeParams; import htsjdk.samtools.cram.structure.block.BlockCompressionMethod; import org.testng.Assert; import org.testng.annotations.DataProvider; @@ -29,6 +30,20 @@ public Object[][] compressorForMethodPositiveTests() { {BlockCompressionMethod.RANS, 1, RANSExternalCompressor.class}, {BlockCompressionMethod.RANS, 0, RANSExternalCompressor.class}, {BlockCompressionMethod.RANS, ExternalCompressor.NO_COMPRESSION_ARG, RANSExternalCompressor.class}, + {BlockCompressionMethod.RANGE, 1, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, 0x00, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.ORDER_FLAG_MASK, 
RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.RLE_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.RLE_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.CAT_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.CAT_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.PACK_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.PACK_FLAG_MASK | RangeParams. ORDER_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.PACK_FLAG_MASK | RangeParams.RLE_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.PACK_FLAG_MASK | RangeParams.RLE_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.EXT_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, RangeParams.EXT_FLAG_MASK | RangeParams.PACK_FLAG_MASK, RangeExternalCompressor.class}, + {BlockCompressionMethod.RANGE, ExternalCompressor.NO_COMPRESSION_ARG, RangeExternalCompressor.class}, }; } @@ -82,4 +97,4 @@ public void testBZip2Decompression() throws IOException { Assert.assertEquals(output, "BZip2 worked".getBytes()); } -} +} \ No newline at end of file diff --git a/src/test/java/htsjdk/samtools/cram/compression/nametokenisation/NameTokenisationTest.java b/src/test/java/htsjdk/samtools/cram/compression/nametokenisation/NameTokenisationTest.java new file mode 100644 index 0000000000..29e487f64a --- /dev/null +++ b/src/test/java/htsjdk/samtools/cram/compression/nametokenisation/NameTokenisationTest.java @@ -0,0 +1,100 @@ +package htsjdk.samtools.cram.compression.nametokenisation; + +import htsjdk.HtsjdkTest; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import 
org.testng.annotations.Test; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +public class NameTokenisationTest extends HtsjdkTest { + + private static class TestDataEnvelope { + public final byte[] testArray; + public TestDataEnvelope(final byte[] testdata) { + this.testArray = testdata; + } + public String toString() { + return String.format("Array of size %d", testArray.length); + } + } + + @DataProvider(name="nameTokenisation") + public Object[][] getNameTokenisationTestData() { + + List readNamesList = new ArrayList<>(); + readNamesList.add(""); + + // a subset of read names from + // src/test/resources/htsjdk/samtools/cram/CEUTrio.HiSeq.WGS.b37.NA12878.20.first.8000.bam + readNamesList.add("20FUKAAXX100202:6:27:4968:125377\n" + + "20FUKAAXX100202:6:27:4986:125375\n" + + "20FUKAAXX100202:5:62:8987:1929\n" + + "20GAVAAXX100126:1:28:4295:139802\n" + + "20FUKAAXX100202:4:23:8516:117251\n" + + "20FUKAAXX100202:6:23:6442:37469\n" + + "20FUKAAXX100202:8:24:10477:24196\n" + + "20GAVAAXX100126:8:63:5797:158250\n" + + "20FUKAAXX100202:1:45:12798:104365\n" + + "20GAVAAXX100126:3:23:6419:199245\n" + + "20FUKAAXX100202:8:48:6663:137967\n" + + "20FUKAAXX100202:6:68:17726:162601"); + + // a subset of read names from + // src/test/resources/htsjdk/samtools/longreads/NA12878.m64020_190210_035026.chr21.5011316.5411316.unmapped.bam + readNamesList.add("m64020_190210_035026/44368402/ccs\n"); + readNamesList.add("m64020_190210_035026/44368402/ccs"); + readNamesList.add("m64020_190210_035026/44368402/ccs\n" + + "m64020_190210_035026/124127126/ccs\n" + + "m64020_190210_035026/4981311/ccs\n" + + "m64020_190210_035026/80022195/ccs\n" + + "m64020_190210_035026/17762104/ccs\n" + + "m64020_190210_035026/62981096/ccs\n" + + "m64020_190210_035026/86968803/ccs\n" + + "m64020_190210_035026/46400955/ccs\n" + + "m64020_190210_035026/137561592/ccs\n" + + "m64020_190210_035026/52233471/ccs\n" + + 
"m64020_190210_035026/97127189/ccs\n" + + "m64020_190210_035026/115278035/ccs\n" + + "m64020_190210_035026/155256324/ccs\n" + + "m64020_190210_035026/163644151/ccs\n" + + "m64020_190210_035026/162728365/ccs\n" + + "m64020_190210_035026/160238116/ccs\n" + + "m64020_190210_035026/147719983/ccs\n" + + "m64020_190210_035026/60883331/ccs\n" + + "m64020_190210_035026/1116165/ccs\n" + + "m64020_190210_035026/75893199/ccs"); + + // source: https://gatk.broadinstitute.org/hc/en-us/articles/360035890671-Read-groups + readNamesList.add( + "H0164ALXX140820:2:1101:10003:23460\n" + + "H0164ALXX140820:2:1101:15118:25288"); + + final List testCases = new ArrayList<>(); + for (String readName : readNamesList) { + Object[] objects = new Object[]{ + new NameTokenisationEncode(), + new NameTokenisationDecode(), + new TestDataEnvelope(readName.getBytes())}; + testCases.add(objects); + } + return testCases.toArray(new Object[][]{}); + } + + @Test(dataProvider = "nameTokenisation") + public void testRoundTrip( + final NameTokenisationEncode nameTokenisationEncode, + final NameTokenisationDecode nameTokenisationDecode, + final TestDataEnvelope td) { + ByteBuffer uncompressedBuffer = ByteBuffer.wrap(td.testArray); + ByteBuffer compressedBuffer = nameTokenisationEncode.compress(uncompressedBuffer, 0); + String decompressedNames = nameTokenisationDecode.uncompress(compressedBuffer); + ByteBuffer decompressedNamesBuffer = StandardCharsets.UTF_8.encode(decompressedNames); + uncompressedBuffer.rewind(); + Assert.assertEquals(decompressedNamesBuffer, uncompressedBuffer); + } + +} \ No newline at end of file diff --git a/src/test/java/htsjdk/samtools/cram/compression/range/RangeTest.java b/src/test/java/htsjdk/samtools/cram/compression/range/RangeTest.java new file mode 100644 index 0000000000..37a9081574 --- /dev/null +++ b/src/test/java/htsjdk/samtools/cram/compression/range/RangeTest.java @@ -0,0 +1,228 @@ +package htsjdk.samtools.cram.compression.range; + +import htsjdk.HtsjdkTest; +import 
htsjdk.samtools.cram.CRAMException; +import htsjdk.samtools.cram.compression.CompressionUtils; +import htsjdk.samtools.util.TestUtil; +import htsjdk.utils.TestNGUtils; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +public class RangeTest extends HtsjdkTest { + private final Random random = new Random(TestUtil.RANDOM_SEED); + + private static class TestDataEnvelope { + public final byte[] testArray; + public TestDataEnvelope(final byte[] testdata) { + this.testArray = testdata; + } + public String toString() { + return String.format("Array of size %d", testArray.length); + } + } + + public Object[][] getRangeEmptyTestData() { + return new Object[][]{ + { new TestDataEnvelope(new byte[]{}) }, + }; + } + + @DataProvider(name = "rangeTestData") + public Object[][] getRangeTestData() { + return new Object[][] { + { new TestDataEnvelope(new byte[] {0}) }, + { new TestDataEnvelope(new byte[] {0, 1}) }, + { new TestDataEnvelope(new byte[] {0, 1, 2}) }, + { new TestDataEnvelope(new byte[] {0, 1, 2, 3}) }, + { new TestDataEnvelope(new byte[1000]) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> (byte) 1)) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> Byte.MIN_VALUE)) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> Byte.MAX_VALUE)) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> (byte) index.intValue())) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> index < n / 2 ? (byte) 0 : (byte) 1)) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> index < n % 2 ? 
(byte) 0 : (byte) 1)) }, + { new TestDataEnvelope(randomBytesFromGeometricDistribution(1000, 0.1)) }, + { new TestDataEnvelope(randomBytesFromGeometricDistribution(1000, 0.01)) }, + { new TestDataEnvelope(randomBytesFromGeometricDistribution(10 * 1000 * 1000 + 1, 0.01)) }, + }; + } + + public Object[][] getRangeTestDataTinySmallLarge() { + + // params: test data, lower limit, upper limit + return new Object[][]{ + { new TestDataEnvelope(randomBytesFromGeometricDistribution(100, 0.1)), 1, 100 }, // Tiny + { new TestDataEnvelope(randomBytesFromGeometricDistribution(1000, 0.01)), 4, 1000 }, // Small + { new TestDataEnvelope(randomBytesFromGeometricDistribution(100 * 1000 + 3, 0.01)), 100 * 1000 + 3 - 4, 100 * 1000 + 3 } // Large + }; + } + + @DataProvider(name="rangeCodecs") + public Object[][] getRangeCodecs() { + + // params: RangeEncoder, RangeDecoder, RangeParams + final List rangeParamsFormatFlagList = Arrays.asList( + 0x00, + RangeParams.ORDER_FLAG_MASK, + RangeParams.RLE_FLAG_MASK, + RangeParams.RLE_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, + RangeParams.CAT_FLAG_MASK, + RangeParams.CAT_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, + RangeParams.PACK_FLAG_MASK, + RangeParams.PACK_FLAG_MASK | RangeParams. 
ORDER_FLAG_MASK, + RangeParams.PACK_FLAG_MASK | RangeParams.RLE_FLAG_MASK, + RangeParams.PACK_FLAG_MASK | RangeParams.RLE_FLAG_MASK | RangeParams.ORDER_FLAG_MASK, + RangeParams.EXT_FLAG_MASK, + RangeParams.EXT_FLAG_MASK | RangeParams.PACK_FLAG_MASK); + final List testCases = new ArrayList<>(); + for (Integer rangeParamsFormatFlag : rangeParamsFormatFlagList) { + Object[] objects = new Object[]{ + new RangeEncode(), + new RangeDecode(), + new RangeParams(rangeParamsFormatFlag) + }; + testCases.add(objects); + } + return testCases.toArray(new Object[][]{}); + } + + public Object[][] getRangeDecodeOnlyCodecs() { + // params: Range encoder, Range decoder, Range params + final List rangeParamsFormatFlagList = Arrays.asList( + RangeParams.STRIPE_FLAG_MASK, + RangeParams.STRIPE_FLAG_MASK | RangeParams.ORDER_FLAG_MASK); + final List testCases = new ArrayList<>(); + for (Integer rangeParamsFormatFlag : rangeParamsFormatFlagList) { + Object[] objects = new Object[]{ + new RangeEncode(), + new RangeDecode(), + new RangeParams(rangeParamsFormatFlag) + }; + testCases.add(objects); + } + return testCases.toArray(new Object[][]{}); + } + + @DataProvider(name="RangeDecodeOnlyAndData") + public Object[][] getRangeDecodeOnlyAndData() { + + // params: Range encoder, Range decoder, Range params, test data + // this data provider provides all the non-empty testdata input for Range codec + return TestNGUtils.cartesianProduct(getRangeDecodeOnlyCodecs(), getRangeTestData()); + } + + @DataProvider(name="allRangeCodecsAndData") + public Object[][] getAllRangeCodecsAndData() { + + // params: RangeEncode, RangeDecode, RangeParams, test data + // this data provider provides all the testdata for all of Range codecs + return Stream.concat( + Arrays.stream(TestNGUtils.cartesianProduct(getRangeCodecs(), getRangeTestData())), + Arrays.stream(TestNGUtils.cartesianProduct(getRangeCodecs(), getRangeEmptyTestData()))) + .toArray(Object[][]::new); + } + + 
@DataProvider(name="allRangeCodecsAndDataForTinySmallLarge") + public Object[][] allRangeCodecsAndDataForTinySmallLarge() { + + // params: RangeEncode, RangeDecode, RangeParams, test data, lower limit, upper limit + // this data provider provides Tiny, Small and Large testdata for all of Range codecs + return TestNGUtils.cartesianProduct(getRangeCodecs(), getRangeTestDataTinySmallLarge()); + } + + @Test(dataProvider = "allRangeCodecsAndData") + public void testRoundTrip(final RangeEncode rangeEncode, + final RangeDecode rangeDecode, + final RangeParams rangeParams, + final TestDataEnvelope td) { + rangeRoundTrip(rangeEncode, rangeDecode, rangeParams, CompressionUtils.wrap(td.testArray)); + } + + @Test(dataProvider = "allRangeCodecsAndDataForTinySmallLarge") + public void testRoundTripTinySmallLarge( + final RangeEncode rangeEncode, + final RangeDecode rangeDecode, + final RangeParams rangeParams, + final TestDataEnvelope td, + final Integer lowerLimit, + final Integer upperLimit){ + final ByteBuffer in = CompressionUtils.wrap(td.testArray); + for (int size = lowerLimit; size < upperLimit; size++) { + in.position(0); + in.limit(size); + rangeRoundTrip(rangeEncode, rangeDecode, rangeParams, in); + } + } + + @Test( + dataProvider = "RangeDecodeOnlyAndData", + expectedExceptions = { CRAMException.class }, + expectedExceptionsMessageRegExp = "Range Encoding with Stripe Flag is not implemented.") + public void testRangeEncodeStripe( + final RangeEncode rangeEncode, + final RangeDecode unused, + final RangeParams params, + final TestDataEnvelope td) { + + // When td is not Empty, Encoding with Stripe Flag should throw an Exception + // as Encode Stripe is not implemented + final ByteBuffer compressed = rangeEncode.compress(CompressionUtils.wrap(td.testArray), params); + } + + // testRangeBuffersMeetBoundaryExpectations + // testRangeHeader + // testRangeEncodeStripe + + private static void rangeRoundTrip( + final RangeEncode rangeEncode, + final RangeDecode rangeDecode, + 
final RangeParams rangeParams, + final ByteBuffer data) { + final ByteBuffer compressed = rangeEncode.compress(data, rangeParams); + final ByteBuffer uncompressed = rangeDecode.uncompress(compressed); + data.rewind(); + Assert.assertEquals(data, uncompressed); + } + +// TODO: Add to utils + private byte[] getNBytesWithValues(final int n, final BiFunction valueForIndex) { + final byte[] data = new byte[n]; + for (int i = 0; i < data.length; i++) { + data[i] = valueForIndex.apply(n, i); + } + return data; + } + // TODO: Add to utils + private byte[] randomBytesFromGeometricDistribution(final int size, final double p) { + final byte[] data = new byte[size]; + for (int i = 0; i < data.length; i++) { + data[i] = drawByteFromGeometricDistribution(p); + } + return data; + } + + /** + * A crude implementation of RNG for sampling geometric distribution. The + * value returned is offset by -1 to include zero. For testing purposes + * only, no refunds! + * + * @param probability the probability of success + * @return an almost random byte value. 
+ */ + // TODO: Add to utils + private byte drawByteFromGeometricDistribution(final double probability) { + final double rand = random.nextDouble(); + final double g = Math.ceil(Math.log(1 - rand) / Math.log(1 - probability)) - 1; + return (byte) g; + } + +} \ No newline at end of file diff --git a/src/test/java/htsjdk/samtools/cram/compression/rans/RansTest.java b/src/test/java/htsjdk/samtools/cram/compression/rans/RansTest.java index 9c5b7c5752..9495d826ef 100644 --- a/src/test/java/htsjdk/samtools/cram/compression/rans/RansTest.java +++ b/src/test/java/htsjdk/samtools/cram/compression/rans/RansTest.java @@ -1,26 +1,38 @@ package htsjdk.samtools.cram.compression.rans; import htsjdk.HtsjdkTest; +import htsjdk.samtools.cram.CRAMException; +import htsjdk.samtools.cram.compression.CompressionUtils; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Decode; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Encode; +import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Params; +import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Decode; +import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Encode; +import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Params; import htsjdk.samtools.util.TestUtil; +import htsjdk.utils.TestNGUtils; import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; - import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.Random; import java.util.function.BiFunction; +import java.util.stream.Stream; /** * Created by vadim on 22/04/2015. 
*/ public class RansTest extends HtsjdkTest { - private Random random = new Random(TestUtil.RANDOM_SEED); + private final Random random = new Random(TestUtil.RANDOM_SEED); // Since some of our test cases use very large byte arrays, so enclose them in a wrapper class since // otherwise IntelliJ serializes them to strings for display in the test output, which is *super*-slow. - private static class TestCaseWrapper { + private static class TestDataEnvelope { public final byte[] testArray; - public TestCaseWrapper(final byte[] testdata) { + public TestDataEnvelope(final byte[] testdata) { this.testArray = testdata; } public String toString() { @@ -28,140 +40,269 @@ public String toString() { } } - @DataProvider(name="ransData") + public Object[][] getRansEmptyTestData() { + return new Object[][]{ + { new TestDataEnvelope(new byte[]{}) }, + }; + } + public Object[][] getRansTestData() { return new Object[][] { - { new TestCaseWrapper(new byte[]{}) }, - { new TestCaseWrapper(new byte[] {0}) }, - { new TestCaseWrapper(new byte[] {0, 1}) }, - { new TestCaseWrapper(new byte[] {0, 1, 2}) }, - { new TestCaseWrapper(new byte[] {0, 1, 2, 3}) }, - { new TestCaseWrapper(new byte[1000]) }, - { new TestCaseWrapper(getNBytesWithValues(1000, (n, index) -> (byte) 1)) }, - { new TestCaseWrapper(getNBytesWithValues(1000, (n, index) -> Byte.MIN_VALUE)) }, - { new TestCaseWrapper(getNBytesWithValues(1000, (n, index) -> Byte.MAX_VALUE)) }, - { new TestCaseWrapper(getNBytesWithValues(1000, (n, index) -> (byte) index.intValue())) }, - { new TestCaseWrapper(getNBytesWithValues(1000, (n, index) -> index < n / 2 ? (byte) 0 : (byte) 1)) }, - { new TestCaseWrapper(getNBytesWithValues(1000, (n, index) -> index < n % 2 ? 
(byte) 0 : (byte) 1)) }, - { new TestCaseWrapper(randomBytesFromGeometricDistribution(1000, 0.1)) }, - { new TestCaseWrapper(randomBytesFromGeometricDistribution(1000, 0.01)) }, - { new TestCaseWrapper(randomBytesFromGeometricDistribution(10 * 1000 * 1000 + 1, 0.01)) }, + { new TestDataEnvelope(new byte[] {0}) }, + { new TestDataEnvelope(new byte[] {0, 1}) }, + { new TestDataEnvelope(new byte[] {0, 1, 2}) }, + { new TestDataEnvelope(new byte[] {0, 1, 2, 3}) }, + { new TestDataEnvelope(new byte[] {1, 2, 3, 4}) }, + { new TestDataEnvelope(new byte[] {1, 2, 3, 4, 5}) }, + { new TestDataEnvelope(new byte[1000]) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> (byte) 1)) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> Byte.MIN_VALUE)) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> Byte.MAX_VALUE)) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> (byte) index.intValue())) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> index < n / 2 ? (byte) 0 : (byte) 1)) }, + { new TestDataEnvelope(getNBytesWithValues(1000, (n, index) -> index < n % 2 ? 
(byte) 0 : (byte) 1)) }, + { new TestDataEnvelope(randomBytesFromGeometricDistribution(10, 0.1)) }, + { new TestDataEnvelope(randomBytesFromGeometricDistribution(31, 0.1)) }, + { new TestDataEnvelope(randomBytesFromGeometricDistribution(32, 0.1)) }, + { new TestDataEnvelope(randomBytesFromGeometricDistribution(33, 0.1)) }, + { new TestDataEnvelope(randomBytesFromGeometricDistribution(1000, 0.1)) }, + { new TestDataEnvelope(randomBytesFromGeometricDistribution(1000, 0.01)) }, + { new TestDataEnvelope(randomBytesFromGeometricDistribution(10 * 1000 * 1000 + 1, 0.01)) }, }; } - @Test(dataProvider="ransData") - public void testRANS(final TestCaseWrapper tc) { - roundTripForEachOrder(tc.testArray); + public Object[][] getRansTestDataTinySmallLarge() { + + // params: test data, lower limit, upper limit + return new Object[][]{ + { new TestDataEnvelope(randomBytesFromGeometricDistribution(100, 0.1)), 1, 100 }, // Tiny + { new TestDataEnvelope(randomBytesFromGeometricDistribution(1000, 0.01)), 4, 1000 }, // Small + { new TestDataEnvelope(randomBytesFromGeometricDistribution(100 * 1000 + 3, 0.01)), 100 * 1000 + 3 - 4, 100 * 1000 + 3 } // Large + }; } - @Test - public void testSizeRangeTiny() { - for (int i = 0; i < 20; i++) { - final byte[] data = randomBytesFromGeometricDistribution(100, 0.1); - final ByteBuffer in = ByteBuffer.wrap(data); - for (int size = 1; size < data.length; size++) { - in.position(0); - in.limit(size); - roundTripForEachOrder(in); - } - } + @DataProvider(name="rans4x8") + public Object[][] getRans4x8Codecs() { + + // params: RANS encoder, RANS decoder, RANS params + return new Object[][]{ + {new RANS4x8Encode(), new RANS4x8Decode(), new RANS4x8Params(RANSParams.ORDER.ZERO)}, + {new RANS4x8Encode(), new RANS4x8Decode(), new RANS4x8Params(RANSParams.ORDER.ONE)} + }; } - @Test - public void testSizeRangeSmall() { - final byte[] data = randomBytesFromGeometricDistribution(1000, 0.01); - final ByteBuffer in = ByteBuffer.wrap(data); - for (int size = 4; 
size < data.length; size++) { - in.position(0); - in.limit(size); - roundTripForEachOrder(in); + @DataProvider(name="ransNx16") + public Object[][] getRansNx16Codecs() { + + // params: RANS encoder, RANS decoder, RANS params + final List ransNx16ParamsFormatFlagList = Arrays.asList( + 0x00, + RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.N32_FLAG_MASK, + RANSNx16Params.N32_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.CAT_FLAG_MASK, + RANSNx16Params.CAT_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.CAT_FLAG_MASK | RANSNx16Params.N32_FLAG_MASK, + RANSNx16Params.CAT_FLAG_MASK | RANSNx16Params.N32_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.RLE_FLAG_MASK, + RANSNx16Params.RLE_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.PACK_FLAG_MASK, + RANSNx16Params.PACK_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK, + RANSNx16Params.RLE_FLAG_MASK | RANSNx16Params.PACK_FLAG_MASK, + RANSNx16Params.RLE_FLAG_MASK | RANSNx16Params.PACK_FLAG_MASK | RANSNx16Params.ORDER_FLAG_MASK + ); + final List testCases = new ArrayList<>(); + for (Integer ransNx16ParamsFormatFlag : ransNx16ParamsFormatFlagList) { + final Object[] objects = new Object[]{ + new RANSNx16Encode(), + new RANSNx16Decode(), + new RANSNx16Params(ransNx16ParamsFormatFlag) + }; + testCases.add(objects); } + return testCases.toArray(new Object[][]{}); } - @Test - public void testLargeSize() { - final int size = 100 * 1000 + 3; - final byte[] data = randomBytesFromGeometricDistribution(size, 0.01); - final ByteBuffer in = ByteBuffer.wrap(data); - for (int limit = size - 4; limit < size; limit++) { - in.position(0); - in.limit(limit); - roundTripForEachOrder(in); - } + public Object[][] getRansNx16Encoder() { + + // params: RANS encoder, RANS params + return new Object[][]{ + {new RANSNx16Encode(), new RANSNx16Params(RANSNx16Params.STRIPE_FLAG_MASK)}, + {new RANSNx16Encode(), new RANSNx16Params(RANSNx16Params.ORDER_FLAG_MASK|RANSNx16Params.STRIPE_FLAG_MASK)} + }; } - 
@Test - public void testBuffersMeetBoundaryExpectations() { - final int size = 1001; - final ByteBuffer raw = ByteBuffer.wrap(randomBytesFromGeometricDistribution(size, 0.01)); - final RANS rans = new RANS(); - for (RANS.ORDER order : RANS.ORDER.values()) { - final ByteBuffer compressed = rans.compress(raw, order); - Assert.assertFalse(raw.hasRemaining()); - Assert.assertEquals(raw.limit(), size); - - Assert.assertEquals(compressed.position(), 0); - Assert.assertTrue(compressed.limit() > 10); - Assert.assertEquals(compressed.get(), (byte) order.ordinal()); - Assert.assertEquals(compressed.getInt(), compressed.limit() - 1 - 4 - 4); - Assert.assertEquals(compressed.getInt(), size); - compressed.rewind(); - - final ByteBuffer uncompressed = rans.uncompress(compressed); - Assert.assertFalse(compressed.hasRemaining()); - Assert.assertEquals(uncompressed.limit(), size); - Assert.assertEquals(uncompressed.position(), 0); - - raw.rewind(); - } + @DataProvider(name="RansNx16RejectEncodeStripe") + public Object[][] getRansNx16RejectEncodeStripe() { + + // params: RANS encoder, RANS decoder, RANS params, test data + // this data provider provides all the non-empty testdata input for RANS Nx16 codec + return TestNGUtils.cartesianProduct(getRansNx16Encoder(), getRansTestData()); } - @Test - public void testRansHeader() { - final byte[] data = randomBytesFromGeometricDistribution(1000, 0.01); - final ByteBuffer compressed = new RANS().compress(ByteBuffer.wrap(data), RANS.ORDER.ZERO); - Assert.assertEquals(compressed.get(), (byte) 0); - Assert.assertEquals(compressed.getInt(), compressed.limit() - 9); - Assert.assertEquals(compressed.getInt(), data.length); + public Object[][] getAllRansCodecs() { + + // params: RANSEncode, RANSDecode, RANSParams + // concatenate RANS4x8 and RANSNx16 codecs + return Stream.concat(Arrays.stream(getRans4x8Codecs()), Arrays.stream(getRansNx16Codecs())) + .toArray(Object[][]::new); } - private byte[] getNBytesWithValues(final int n, final BiFunction 
valueForIndex) { - final byte[] data = new byte[n]; - for (int i = 0; i < data.length; i++) { - data[i] = valueForIndex.apply(n, i); - } - return data; + @DataProvider(name="allRansAndData") + public Object[][] getAllRansAndData() { + + // params: RANSEncode, RANSDecode, RANSParams, test data + // this data provider provides all the testdata for all of RANS codecs + return Stream.concat( + Arrays.stream(TestNGUtils.cartesianProduct(getAllRansCodecs(), getRansTestData())), + Arrays.stream(TestNGUtils.cartesianProduct(getAllRansCodecs(), getRansEmptyTestData()))) + .toArray(Object[][]::new); } - private static void roundTripForEachOrder(final ByteBuffer data) { - for (RANS.ORDER order : RANS.ORDER.values()) { - roundTripForOrder(data, order); - data.rewind(); - } + @DataProvider(name="allRansAndDataForTinySmallLarge") + public Object[][] getAllRansAndDataForTinySmallLarge() { + + // params: RANSEncode, RANSDecode, RANSParams, test data, lower limit, upper limit + // this data provider provides Tiny, Small and Large testdata for all of RANS codecs + return TestNGUtils.cartesianProduct(getAllRansCodecs(), getRansTestDataTinySmallLarge()); } - private static void roundTripForEachOrder(final byte[] data) { - for (RANS.ORDER order : RANS.ORDER.values()) { - roundTripForOrder(data, order); + @Test(dataProvider = "allRansAndDataForTinySmallLarge") + public void testRoundTripTinySmallLarge( + final RANSEncode ransEncode, + final RANSDecode ransDecode, + final RANSParams params, + final TestDataEnvelope td, + final Integer lowerLimit, + final Integer upperLimit){ + final ByteBuffer in = CompressionUtils.wrap(td.testArray); + for (int rawSize = lowerLimit; rawSize < upperLimit; rawSize++) { + in.position(0); + in.limit(rawSize); + ransRoundTrip(ransEncode, ransDecode, params, in); } } - private static void roundTripForOrder(final ByteBuffer data, final RANS.ORDER order) { - final RANS rans = new RANS(); - final ByteBuffer compressed = rans.compress(data, order); - final 
ByteBuffer uncompressed = rans.uncompress(compressed); - data.rewind(); - while (data.hasRemaining()) { - if (!uncompressed.hasRemaining()) { - Assert.fail("Premature end of uncompressed data."); + @Test(dataProvider = "rans4x8") + public void testRans4x8BuffersMeetBoundaryExpectations( + final RANS4x8Encode ransEncode, + final RANS4x8Decode ransDecode, + final RANS4x8Params params) { + final int rawSize = 1001; + final ByteBuffer rawData = CompressionUtils.wrap(randomBytesFromGeometricDistribution(rawSize, 0.01)); + final ByteBuffer compressed = ransBufferMeetBoundaryExpectations(rawSize,rawData,ransEncode, ransDecode,params); + Assert.assertTrue(compressed.limit() > Constants.RANS_4x8_PREFIX_BYTE_LENGTH); // minimum prefix len when input is not Empty + Assert.assertEquals(compressed.get(), (byte) params.getOrder().ordinal()); + Assert.assertEquals(compressed.getInt(), compressed.limit() - Constants.RANS_4x8_PREFIX_BYTE_LENGTH); + Assert.assertEquals(compressed.getInt(), rawSize); + } + + @Test(dataProvider = "ransNx16") + public void testRansNx16BuffersMeetBoundaryExpectations( + final RANSNx16Encode ransEncode, + final RANSNx16Decode ransDecode, + final RANSNx16Params params) { + final int rawSize = 1001; + final ByteBuffer rawData = CompressionUtils.wrap(randomBytesFromGeometricDistribution(rawSize, 0.01)); + final ByteBuffer compressed = ransBufferMeetBoundaryExpectations(rawSize,rawData,ransEncode,ransDecode,params); + rawData.rewind(); + Assert.assertTrue(compressed.limit() > 1); // minimum prefix len when input is not Empty + final int FormatFlags = compressed.get() & 0xFF; // first byte of compressed data is the formatFlags + final int[] frequencies = new int[Constants.NUMBER_OF_SYMBOLS]; + final int inSize = rawData.remaining(); + for (int i = 0; i < inSize; i ++) { + frequencies[rawData.get(i) & 0xFF]++; + } + int numSym = 0; + for (int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i++) { + if (frequencies[i]>0) { + numSym++; } - 
Assert.assertEquals(uncompressed.get(), data.get()); } - Assert.assertFalse(uncompressed.hasRemaining()); + if (params.isPack() & (numSym == 0 | numSym > 16)) { + // In the encoder, Packing is skipped if numSymbols = 0 or numSymbols > 16 + // and the Pack flag is unset in the formatFlags + Assert.assertEquals(FormatFlags, params.getFormatFlags() & ~RANSNx16Params.PACK_FLAG_MASK); + } else { + Assert.assertEquals(FormatFlags, params.getFormatFlags()); + } + // if nosz flag is not set, then the uncompressed size is recorded + if (!params.isNosz()){ + Assert.assertEquals(CompressionUtils.readUint7(compressed), rawSize); + } + } + + @Test(dataProvider="allRansAndData") + public void testRoundTrip( + final RANSEncode ransEncode, + final RANSDecode ransDecode, + final RANSParams params, + final TestDataEnvelope td) { + ransRoundTrip(ransEncode, ransDecode, params, CompressionUtils.wrap(td.testArray)); } - private static void roundTripForOrder(final byte[] data, final RANS.ORDER order) { - roundTripForOrder(ByteBuffer.wrap(data), order); + @Test( + dataProvider = "RansNx16RejectEncodeStripe", + expectedExceptions = { CRAMException.class }, + expectedExceptionsMessageRegExp = "RANSNx16 Encoding with Stripe Flag is not implemented.") + public void testRansNx16RejectEncodeStripe( + final RANSNx16Encode ransEncode, + final RANSNx16Params params, + final TestDataEnvelope td) { + + // When td is not Empty, Encoding with Stripe Flag should throw an Exception + // as Encode Stripe is not implemented + ransEncode.compress(CompressionUtils.wrap(td.testArray), params); + } + + @Test( + description = "RANSNx16 Decoding with Pack Flag if (numSymbols > 16 or numSymbols==0) " + + "should throw CRAMException", + expectedExceptions = { CRAMException.class }, + expectedExceptionsMessageRegExp = "Bit Packing is not permitted when number " + + "of distinct symbols is greater than 16 or equal to 0. 
Number of distinct symbols: 0") + public void testRANSNx16RejectDecodePack(){ + final ByteBuffer compressedData = CompressionUtils.wrap(new byte[]{(byte) RANSNx16Params.PACK_FLAG_MASK, (byte) 0x00, (byte) 0x00}); + final RANSNx16Decode ransDecode = new RANSNx16Decode(); + ransDecode.uncompress(compressedData); + } + + private static void ransRoundTrip( + final RANSEncode ransEncode, + final RANSDecode ransDecode, + final RANSParams params, + final ByteBuffer data) { + final ByteBuffer compressed = ransEncode.compress(data, params); + final ByteBuffer uncompressed = ransDecode.uncompress(compressed); + data.rewind(); + Assert.assertEquals(data, uncompressed); + } + + public ByteBuffer ransBufferMeetBoundaryExpectations( + final int rawSize, + final ByteBuffer raw, + final RANSEncode ransEncode, + final RANSDecode ransDecode, + final RANSParams params){ + // helper method for Boundary Expectations test + final ByteBuffer compressed = ransEncode.compress(raw, params); + final ByteBuffer uncompressed = ransDecode.uncompress(compressed); + Assert.assertFalse(compressed.hasRemaining()); + compressed.rewind(); + Assert.assertEquals(uncompressed.limit(), rawSize); + Assert.assertEquals(uncompressed.position(), 0); + Assert.assertFalse(raw.hasRemaining()); + Assert.assertEquals(raw.limit(), rawSize); + Assert.assertEquals(compressed.position(), 0); + return compressed; + } + + private byte[] getNBytesWithValues(final int n, final BiFunction valueForIndex) { + final byte[] data = new byte[n]; + for (int i = 0; i < data.length; i++) { + data[i] = valueForIndex.apply(n, i); + } + return data; } private byte[] randomBytesFromGeometricDistribution(final int size, final double p) { @@ -185,4 +326,5 @@ private byte drawByteFromGeometricDistribution(final double probability) { final double g = Math.ceil(Math.log(1 - rand) / Math.log(1 - probability)) - 1; return (byte) g; } -} + +} \ No newline at end of file diff --git a/src/test/java/htsjdk/utils/SamtoolsTestUtils.java 
b/src/test/java/htsjdk/utils/SamtoolsTestUtils.java index f70674d8cd..eda5fcb46e 100644 --- a/src/test/java/htsjdk/utils/SamtoolsTestUtils.java +++ b/src/test/java/htsjdk/utils/SamtoolsTestUtils.java @@ -1,7 +1,8 @@ package htsjdk.utils; -import htsjdk.samtools.util.*; - +import htsjdk.samtools.util.FileExtensions; +import htsjdk.samtools.util.ProcessExecutor; +import htsjdk.samtools.util.RuntimeIOException; import java.io.File; import java.io.IOException; import java.nio.file.Files; @@ -14,6 +15,7 @@ public class SamtoolsTestUtils { private static final String SAMTOOLS_BINARY_ENV_VARIABLE = "HTSJDK_SAMTOOLS_BIN"; public final static String expectedSamtoolsVersion = "1.14"; + public final static String expectedHtslibVersion = "1.14"; /** * @return true if samtools is available, otherwise false @@ -47,6 +49,11 @@ public static String getSamtoolsBin() { return samtoolsPath == null ? "/usr/local/bin/samtools" : samtoolsPath; } + public static String getCRAMInteropData() { + final String samtoolsPath = System.getenv(SAMTOOLS_BINARY_ENV_VARIABLE); + return samtoolsPath == null ? "../htscodecs/tests" : "./samtools-"+expectedSamtoolsVersion+ "/htslib-"+expectedHtslibVersion+"/htscodecs/tests"; + } + /** * Execute a samtools command line if a local samtools executable is available see {@link #isSamtoolsAvailable()}. *