diff --git a/moonlight-common/src/com/limelight/nvstream/control/ControlStream.java b/moonlight-common/src/com/limelight/nvstream/control/ControlStream.java index ffe33c50..45609566 100644 --- a/moonlight-common/src/com/limelight/nvstream/control/ControlStream.java +++ b/moonlight-common/src/com/limelight/nvstream/control/ControlStream.java @@ -28,6 +28,7 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende private static final int IDX_START_B = 1; private static final int IDX_INVALIDATE_REF_FRAMES = 2; private static final int IDX_LOSS_STATS = 3; + private static final int IDX_INPUT_DATA = 5; private static final short packetTypesGen3[] = { 0x140b, // Start A @@ -35,6 +36,7 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende 0x1404, // Invalidate reference frames 0x140c, // Loss Stats 0x1417, // Frame Stats (unused) + -1, // Input data (unused) }; private static final short packetTypesGen4[] = { 0x0606, // Request IDR frame @@ -42,6 +44,7 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende 0x0604, // Invalidate reference frames 0x060a, // Loss Stats 0x0611, // Frame Stats (unused) + -1, // Input data (unused) }; private static final short packetTypesGen5[] = { 0x0305, // Start A @@ -49,6 +52,15 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende 0x0301, // Invalidate reference frames 0x0201, // Loss Stats 0x0204, // Frame Stats (unused) + 0x0207, // Input data + }; + private static final short packetTypesGen7[] = { + 0x0305, // Start A + 0x0307, // Start B + 0x0301, // Invalidate reference frames + 0x0201, // Loss Stats + 0x0204, // Frame Stats (unused) + 0x0206, // Input data }; private static final short payloadLengthsGen3[] = { @@ -57,6 +69,7 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende 24, // Invalidate reference frames 32, // Loss Stats 64, // Frame Stats + -1, // Input Data }; private static final short payloadLengthsGen4[] = { -1, // Request IDR frame @@ -64,6 +77,7 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende 24, // Invalidate reference frames 32, // Loss Stats 64, // Frame Stats + -1, // Input Data }; private static final short payloadLengthsGen5[] = { -1, // Start A @@ -71,6 +85,15 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende 24, // Invalidate reference frames 32, // Loss Stats 80, // Frame Stats + -1, // Input Data + }; + private static final short payloadLengthsGen7[] = { + -1, // Start A + 16, // Start B + 24, // Invalidate reference frames + 32, // Loss Stats + 80, // Frame Stats + -1, // Input Data }; private static final byte[] precontructedPayloadsGen3[] = { @@ -79,6 +102,7 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende null, // Invalidate reference frames null, // Loss Stats null, // Frame Stats + null, // Input Data }; private static final byte[] precontructedPayloadsGen4[] = { new byte[]{0, 0}, // Request IDR frame @@ -86,6 +110,7 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende null, // Invalidate reference frames null, // Loss Stats null, // Frame Stats + null, // Input Data }; private static final byte[] precontructedPayloadsGen5[] = { new byte[]{0, 0}, // Start A @@ -93,6 +118,15 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende null, // Invalidate reference frames null, // Loss Stats null, // Frame Stats + null, // Input Data + }; + private static final byte[] precontructedPayloadsGen7[] = { + new byte[]{0, 0}, // Start A + null, // Start B + null, // Invalidate reference frames + null, // Loss Stats + null, // Frame Stats + null, // Input Data }; public static final int LOSS_REPORT_INTERVAL_MS = 50; @@ -154,11 +188,16 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende preconstructedPayloads = precontructedPayloadsGen4; break; case ConnectionContext.SERVER_GENERATION_5: - default: packetTypes = packetTypesGen5; payloadLengths = payloadLengthsGen5; preconstructedPayloads = precontructedPayloadsGen5; break; + case ConnectionContext.SERVER_GENERATION_7: + default: + packetTypes = packetTypesGen7; + payloadLengths = payloadLengthsGen7; + preconstructedPayloads = precontructedPayloadsGen7; + break; } if (context.videoDecoderRenderer != null) { @@ -225,7 +264,7 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende } public void sendInputPacket(byte[] data, short length) throws IOException { - sendPacket(new NvCtlPacket((short) 0x0207, length, data)); + sendPacket(new NvCtlPacket(packetTypes[IDX_INPUT_DATA], length, data)); } public void abort() @@ -565,7 +604,7 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende serializationBuffer.limit(serializationBuffer.capacity()); serializationBuffer.putShort(type); serializationBuffer.putShort(paylen); - serializationBuffer.put(payload); + serializationBuffer.put(payload, 0, paylen); out.write(serializationBuffer.array(), 0, serializationBuffer.position()); } @@ -578,7 +617,7 @@ public class ControlStream implements ConnectionStatusListener, InputPacketSende serializationBuffer.rewind(); serializationBuffer.limit(serializationBuffer.capacity()); serializationBuffer.putShort(type); - serializationBuffer.put(payload); + serializationBuffer.put(payload, 0, paylen); serializationBuffer.limit(serializationBuffer.position()); conn.writePacket(serializationBuffer); diff --git a/moonlight-common/src/com/limelight/nvstream/input/ControllerStream.java b/moonlight-common/src/com/limelight/nvstream/input/ControllerStream.java index ea506659..ddbb6498 100644 --- a/moonlight-common/src/com/limelight/nvstream/input/ControllerStream.java +++ b/moonlight-common/src/com/limelight/nvstream/input/ControllerStream.java @@ -12,8 +12,12 @@ import java.security.NoSuchAlgorithmException; import java.util.Iterator; import java.util.concurrent.LinkedBlockingQueue; +import javax.crypto.BadPaddingException; import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.ShortBufferException; import javax.crypto.spec.IvParameterSpec; import com.limelight.nvstream.ConnectionContext; @@ -34,7 +38,7 @@ public class ControllerStream { // Used on Gen 5+ servers private InputPacketSender controlSender; - private Cipher riCipher; + private InputCipher cipher; private Thread inputThread; private LinkedBlockingQueue inputQueue = new LinkedBlockingQueue(); @@ -45,23 +49,19 @@ public class ControllerStream { public ControllerStream(ConnectionContext context) { this.context = context; - try { - // This cipher is guaranteed to be supported - this.riCipher = Cipher.getInstance("AES/CBC/NoPadding"); - - ByteBuffer bb = ByteBuffer.allocate(16); - bb.putInt(context.riKeyId); - - this.riCipher.init(Cipher.ENCRYPT_MODE, context.riKey, new IvParameterSpec(bb.array())); - } catch (NoSuchAlgorithmException e) { - e.printStackTrace(); - } catch (NoSuchPaddingException e) { - e.printStackTrace(); - } catch (InvalidKeyException e) { - e.printStackTrace(); - } catch (InvalidAlgorithmParameterException e) { - e.printStackTrace(); + + if (context.serverGeneration >= ConnectionContext.SERVER_GENERATION_7) { + // Newer GFE versions use AES GCM + cipher = new AesGcmCipher(); } + else { + // Older versions used AES CBC + cipher = new AesCbcCipher(); + } + + ByteBuffer bb = ByteBuffer.allocate(16); + bb.putInt(context.riKeyId); + cipher.initialize(context.riKey, bb.array()); } public void initialize(InputPacketSender controlSender) throws IOException @@ -217,47 +217,19 @@ public class ControllerStream { } } - private static int getPaddedSize(int length) { - return ((length + 15) / 16) * 16; - } - - private static int inPlacePadData(byte[] data, int length) { - // This implements the PKCS7 padding algorithm - - if ((length % 16) == 0) { - // Already a multiple of 16 - return length; - } - - int paddedLength = getPaddedSize(length); - byte paddingByte = (byte)(16 - (length % 16)); - - for (int i = length; i < paddedLength; i++) { - data[i] = paddingByte; - } - - return paddedLength; - } - - private int encryptAesInputData(byte[] inputData, int inputLength, byte[] outputData, int outputOffset) throws Exception { - int encryptedLength = inPlacePadData(inputData, inputLength); - riCipher.update(inputData, 0, encryptedLength, outputData, outputOffset); - return encryptedLength; - } - private void sendPacket(InputPacket packet) throws IOException { // Store the packet in wire form in the byte buffer packet.toWire(stagingBuffer); int packetLen = packet.getPacketLength(); - // Pad to 16 byte chunks - int paddedLength = getPaddedSize(packetLen); + // Get final encrypted size of this block + int paddedLength = cipher.getEncryptedSize(packetLen); // Allocate a byte buffer to represent the final packet sendBuffer.rewind(); sendBuffer.putInt(paddedLength); try { - encryptAesInputData(stagingBuffer.array(), packetLen, sendBuffer.array(), 4); + cipher.encrypt(stagingBuffer.array(), packetLen, sendBuffer.array(), 4); } catch (Exception e) { // Should never happen e.printStackTrace(); @@ -339,4 +311,107 @@ public class ControllerStream { { queuePacket(new MouseScrollPacket(context, scrollClicks)); } + + private static interface InputCipher { + public void initialize(SecretKey key, byte[] iv); + public int getEncryptedSize(int plaintextSize); + public void encrypt(byte[] inputData, int inputLength, byte[] outputData, int outputOffset); + } + + private static class AesCbcCipher implements InputCipher { + private Cipher cipher; + + public void initialize(SecretKey key, byte[] iv) { + try { + cipher = Cipher.getInstance("AES/CBC/NoPadding"); + cipher.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + } catch (NoSuchAlgorithmException e) { + e.printStackTrace(); + } catch (NoSuchPaddingException e) { + e.printStackTrace(); + } catch (InvalidKeyException e) { + e.printStackTrace(); + } catch (InvalidAlgorithmParameterException e) { + e.printStackTrace(); + } + } + + public int getEncryptedSize(int plaintextSize) { + // CBC requires padding to the next multiple of 16 + return ((plaintextSize + 15) / 16) * 16; + } + + private int inPlacePadData(byte[] data, int length) { + // This implements the PKCS7 padding algorithm + + if ((length % 16) == 0) { + // Already a multiple of 16 + return length; + } + + int paddedLength = getEncryptedSize(length); + byte paddingByte = (byte)(16 - (length % 16)); + + for (int i = length; i < paddedLength; i++) { + data[i] = paddingByte; + } + + return paddedLength; + } + + public void encrypt(byte[] inputData, int inputLength, byte[] outputData, int outputOffset) { + int encryptedLength = inPlacePadData(inputData, inputLength); + try { + cipher.update(inputData, 0, encryptedLength, outputData, outputOffset); + } catch (ShortBufferException e) { + e.printStackTrace(); + } + } + } + + private static class AesGcmCipher implements InputCipher { + private SecretKey key; + private byte[] iv; + + public int getEncryptedSize(int plaintextSize) { + // GCM uses no padding + 16 bytes tag for message authentication + return plaintextSize + 16; + } + + @Override + public void initialize(SecretKey key, byte[] iv) { + this.key = key; + this.iv = iv; + } + + @Override + public void encrypt(byte[] inputData, int inputLength, byte[] outputData, int outputOffset) { + // Reconstructing the cipher on every invocation really sucks but we have to do it + // because of the way NVIDIA is using GCM where each message is tagged. Java doesn't + // have an easy way that I know of to get a tag out mid-stream. + Cipher cipher; + try { + cipher = Cipher.getInstance("AES/GCM/NoPadding"); + cipher.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + + // This is also non-ideal. Java gives us but we want to send + // so we'll take the output and arraycopy it into the right spot in the output buffer + byte[] rawCipherOut = cipher.doFinal(inputData, 0, inputLength); + System.arraycopy(rawCipherOut, inputLength, outputData, outputOffset, 16); + System.arraycopy(rawCipherOut, 0, outputData, outputOffset + 16, inputLength); + } catch (NoSuchAlgorithmException e) { + e.printStackTrace(); + } catch (NoSuchPaddingException e) { + e.printStackTrace(); + } catch (InvalidKeyException e) { + e.printStackTrace(); + } catch (InvalidAlgorithmParameterException e) { + e.printStackTrace(); + } catch (IllegalBlockSizeException e) { + e.printStackTrace(); + } catch (BadPaddingException e) { + e.printStackTrace(); + } + } + } }