diff --git a/moonlight-common/.gitignore b/moonlight-common/.gitignore index 6b468b62..ec995b00 100644 --- a/moonlight-common/.gitignore +++ b/moonlight-common/.gitignore @@ -1 +1,2 @@ *.class +/bin diff --git a/moonlight-common/libs/tinyrtsp.jar b/moonlight-common/libs/tinyrtsp.jar index 8020c06f..e87fa76d 100644 Binary files a/moonlight-common/libs/tinyrtsp.jar and b/moonlight-common/libs/tinyrtsp.jar differ diff --git a/moonlight-common/src/com/limelight/nvstream/NvConnection.java b/moonlight-common/src/com/limelight/nvstream/NvConnection.java index 99796372..f03524ab 100644 --- a/moonlight-common/src/com/limelight/nvstream/NvConnection.java +++ b/moonlight-common/src/com/limelight/nvstream/NvConnection.java @@ -5,11 +5,15 @@ import java.net.InetAddress; import java.net.NetworkInterface; import java.net.SocketException; import java.net.UnknownHostException; +import java.security.NoSuchAlgorithmException; import java.util.Enumeration; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; + import org.xmlpull.v1.XmlPullParserException; import com.limelight.LimeLog; @@ -19,8 +23,10 @@ import com.limelight.nvstream.av.video.VideoDecoderRenderer; import com.limelight.nvstream.av.video.VideoStream; import com.limelight.nvstream.control.ControlStream; import com.limelight.nvstream.http.GfeHttpResponseException; +import com.limelight.nvstream.http.LimelightCryptoProvider; import com.limelight.nvstream.http.NvApp; import com.limelight.nvstream.http.NvHTTP; +import com.limelight.nvstream.http.PairingManager; import com.limelight.nvstream.input.NvController; import com.limelight.nvstream.rtsp.RtspConnection; @@ -28,6 +34,7 @@ public class NvConnection { private String host; private NvConnectionListener listener; private StreamConfiguration config; + private LimelightCryptoProvider cryptoProvider; private InetAddress hostAddr; private ControlStream controlStream; @@ -41,19 +48,38 @@ public class NvConnection { private VideoDecoderRenderer videoDecoderRenderer; private AudioRenderer audioRenderer; private String localDeviceName; + private SecretKey riKey; private ThreadPoolExecutor threadPool; - public NvConnection(String host, NvConnectionListener listener, StreamConfiguration config) + public NvConnection(String host, NvConnectionListener listener, StreamConfiguration config, LimelightCryptoProvider cryptoProvider) { this.host = host; this.listener = listener; this.config = config; + this.cryptoProvider = cryptoProvider; + + try { + // This is unique per connection + this.riKey = generateRiAesKey(); + } catch (NoSuchAlgorithmException e) { + // Should never happen + e.printStackTrace(); + } this.threadPool = new ThreadPoolExecutor(1, 1, Long.MAX_VALUE, TimeUnit.DAYS, new LinkedBlockingQueue(), new ThreadPoolExecutor.DiscardPolicy()); } + private static SecretKey generateRiAesKey() throws NoSuchAlgorithmException { + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + + // RI keys are 128 bits + keyGen.init(128); + + return keyGen.generateKey(); + } + public static String getMacAddressString() throws SocketException { Enumeration ifaceList; NetworkInterface selectedIface = null; @@ -136,24 +162,18 @@ public class NvConnection { private boolean startApp() throws XmlPullParserException, IOException { - NvHTTP h = new NvHTTP(hostAddr, getMacAddressString(), localDeviceName); + NvHTTP h = new NvHTTP(hostAddr, getMacAddressString(), localDeviceName, cryptoProvider); - if (h.getAppVersion().startsWith("1.")) { + if (h.getServerVersion().startsWith("1.")) { listener.displayMessage("Limelight now requires GeForce Experience 2.0.1 or later. Please upgrade GFE on your PC and try again."); return false; } - if (!h.getPairState()) { + if (h.getPairState() != PairingManager.PairState.PAIRED) { listener.displayMessage("Device not paired with computer"); return false; } - - int sessionId = h.getSessionId(); - if (sessionId == 0) { - listener.displayMessage("Invalid session ID"); - return false; - } - + NvApp app = h.getApp(config.getApp()); if (app == null) { listener.displayMessage("The app " + config.getApp() + " is not in GFE app list"); @@ -163,7 +183,7 @@ public class NvConnection { // If there's a game running, resume it if (h.getCurrentGame() != 0) { try { - if (h.getCurrentGame() == app.getAppId() && !h.resumeApp()) { + if (h.getCurrentGame() == app.getAppId() && !h.resumeApp(riKey)) { listener.displayMessage("Failed to resume existing session"); return false; } else if (h.getCurrentGame() != app.getAppId()) { @@ -209,7 +229,7 @@ public class NvConnection { throws IOException, XmlPullParserException { // Launch the app since it's not running int gameSessionId = h.launchApp(app.getAppId(), config.getWidth(), - config.getHeight(), config.getRefreshRate()); + config.getHeight(), config.getRefreshRate(), riKey); if (gameSessionId == 0) { listener.displayMessage("Failed to launch application"); return false; @@ -255,7 +275,7 @@ public class NvConnection { // it to the instance variable once the object is properly initialized. // This avoids the race where inputStream != null but inputStream.initialize() // has not returned yet. - NvController tempController = new NvController(hostAddr); + NvController tempController = new NvController(hostAddr, riKey); tempController.initialize(); inputStream = tempController; return true; diff --git a/moonlight-common/src/com/limelight/nvstream/av/DecodeUnit.java b/moonlight-common/src/com/limelight/nvstream/av/DecodeUnit.java index b09ffbf5..3de796ca 100644 --- a/moonlight-common/src/com/limelight/nvstream/av/DecodeUnit.java +++ b/moonlight-common/src/com/limelight/nvstream/av/DecodeUnit.java @@ -13,16 +13,16 @@ public class DecodeUnit { private int type; private List bufferList; private int dataLength; - private int flags; private int frameNumber; + private long receiveTimestamp; - public DecodeUnit(int type, List bufferList, int dataLength, int flags, int frameNumber) + public DecodeUnit(int type, List bufferList, int dataLength, int frameNumber, long receiveTimestamp) { this.type = type; this.bufferList = bufferList; this.dataLength = dataLength; - this.flags = flags; this.frameNumber = frameNumber; + this.receiveTimestamp = receiveTimestamp; } public int getType() @@ -30,9 +30,9 @@ public class DecodeUnit { return type; } - public int getFlags() + public long getReceiveTimestamp() { - return flags; + return receiveTimestamp; } public List getBufferList() diff --git a/moonlight-common/src/com/limelight/nvstream/av/video/VideoDecoderRenderer.java b/moonlight-common/src/com/limelight/nvstream/av/video/VideoDecoderRenderer.java index 335c53d2..48e5a6db 100644 --- a/moonlight-common/src/com/limelight/nvstream/av/video/VideoDecoderRenderer.java +++ b/moonlight-common/src/com/limelight/nvstream/av/video/VideoDecoderRenderer.java @@ -1,24 +1,17 @@ package com.limelight.nvstream.av.video; -import com.limelight.nvstream.av.DecodeUnit; - public interface VideoDecoderRenderer { public static final int FLAG_PREFER_QUALITY = 0x1; public static final int FLAG_FORCE_HARDWARE_DECODING = 0x2; public static final int FLAG_FORCE_SOFTWARE_DECODING = 0x4; - // SubmitDecodeUnit() is lightweight, so don't use an extra thread for decoding - public static final int CAPABILITY_DIRECT_SUBMIT = 0x1; - public int getCapabilities(); public void setup(int width, int height, int redrawRate, Object renderTarget, int drFlags); - public void start(); + public void start(VideoDepacketizer depacketizer); public void stop(); public void release(); - - public boolean submitDecodeUnit(DecodeUnit decodeUnit); } diff --git a/moonlight-common/src/com/limelight/nvstream/av/video/VideoDepacketizer.java b/moonlight-common/src/com/limelight/nvstream/av/video/VideoDepacketizer.java index 32e6fc32..9cf2ee0d 100644 --- a/moonlight-common/src/com/limelight/nvstream/av/video/VideoDepacketizer.java +++ b/moonlight-common/src/com/limelight/nvstream/av/video/VideoDepacketizer.java @@ -23,19 +23,18 @@ public class VideoDepacketizer { private int startFrameNumber = 1; private boolean waitingForNextSuccessfulFrame; private boolean gotNextFrameStart; + private long frameStartTime; // Cached objects private ByteBufferDescriptor cachedDesc = new ByteBufferDescriptor(null, 0, 0); private ConnectionStatusListener controlListener; - private VideoDecoderRenderer directSubmitDr; private static final int DU_LIMIT = 15; private LinkedBlockingQueue decodedUnits = new LinkedBlockingQueue(DU_LIMIT); - public VideoDepacketizer(VideoDecoderRenderer directSubmitDr, ConnectionStatusListener controlListener) + public VideoDepacketizer(ConnectionStatusListener controlListener) { - this.directSubmitDr = directSubmitDr; this.controlListener = controlListener; } @@ -49,30 +48,9 @@ public class VideoDepacketizer { { // This is the start of a new frame if (avcFrameDataChain != null && avcFrameDataLength != 0) { - int flags = 0; - - ByteBufferDescriptor firstBuffer = avcFrameDataChain.getFirst(); - - if (NAL.getSpecialSequenceDescriptor(firstBuffer, cachedDesc) && NAL.isAvcFrameStart(cachedDesc)) { - switch (cachedDesc.data[cachedDesc.offset+cachedDesc.length]) { - case 0x67: - case 0x68: - flags |= DecodeUnit.DU_FLAG_CODEC_CONFIG; - break; - - case 0x65: - flags |= DecodeUnit.DU_FLAG_SYNC_FRAME; - break; - } - } - // Construct the H264 decode unit - DecodeUnit du = new DecodeUnit(DecodeUnit.TYPE_H264, avcFrameDataChain, avcFrameDataLength, flags, frameNumber); - if (directSubmitDr != null) { - // Submit directly to the decoder - directSubmitDr.submitDecodeUnit(du); - } - else if (!decodedUnits.offer(du)) { + DecodeUnit du = new DecodeUnit(DecodeUnit.TYPE_H264, avcFrameDataChain, avcFrameDataLength, frameNumber, frameStartTime); + if (!decodedUnits.offer(du)) { LimeLog.warning("Video decoder is too slow! Forced to drop decode units"); // Invalidate all frames from the start of the DU queue @@ -92,7 +70,7 @@ public class VideoDepacketizer { } } - public void addInputDataSlow(VideoPacket packet, ByteBufferDescriptor location) + private void addInputDataSlow(VideoPacket packet, ByteBufferDescriptor location) { while (location.length != 0) { @@ -175,10 +153,11 @@ public class VideoDepacketizer { } } - public void addInputDataFast(VideoPacket packet, ByteBufferDescriptor location, boolean firstPacket) + private void addInputDataFast(VideoPacket packet, ByteBufferDescriptor location, boolean firstPacket) { if (firstPacket) { // Setup state for the new frame + frameStartTime = System.currentTimeMillis(); avcFrameDataChain = new LinkedList(); avcFrameDataLength = 0; } @@ -340,10 +319,15 @@ public class VideoDepacketizer { addInputData(new VideoPacket(rtpPayload)); } - public DecodeUnit getNextDecodeUnit() throws InterruptedException + public DecodeUnit takeNextDecodeUnit() throws InterruptedException { return decodedUnits.take(); } + + public DecodeUnit pollNextDecodeUnit() + { + return decodedUnits.poll(); + } } class NAL { diff --git a/moonlight-common/src/com/limelight/nvstream/av/video/VideoStream.java b/moonlight-common/src/com/limelight/nvstream/av/video/VideoStream.java index b47ee96a..61fb517d 100644 --- a/moonlight-common/src/com/limelight/nvstream/av/video/VideoStream.java +++ b/moonlight-common/src/com/limelight/nvstream/av/video/VideoStream.java @@ -13,7 +13,6 @@ import java.util.LinkedList; import com.limelight.nvstream.NvConnectionListener; import com.limelight.nvstream.StreamConfiguration; import com.limelight.nvstream.av.ByteBufferDescriptor; -import com.limelight.nvstream.av.DecodeUnit; import com.limelight.nvstream.av.RtpPacket; import com.limelight.nvstream.av.ConnectionStatusListener; @@ -140,12 +139,7 @@ public class VideoStream { decRend.setup(streamConfig.getWidth(), streamConfig.getHeight(), 60, renderTarget, drFlags); - if ((decRend.getCapabilities() & VideoDecoderRenderer.CAPABILITY_DIRECT_SUBMIT) != 0) { - depacketizer = new VideoDepacketizer(decRend, avConnListener); - } - else { - depacketizer = new VideoDepacketizer(null, avConnListener); - } + depacketizer = new VideoDepacketizer(avConnListener); } } @@ -173,44 +167,12 @@ public class VideoStream { // early packets startReceiveThread(); - // Start a decode thread if we're not doing direct submit - if ((decRend.getCapabilities() & VideoDecoderRenderer.CAPABILITY_DIRECT_SUBMIT) == 0) { - startDecoderThread(); - } - // Start the renderer - decRend.start(); + decRend.start(depacketizer); startedRendering = true; } } - private void startDecoderThread() - { - Thread t = new Thread() { - @Override - public void run() { - // Read the decode units generated from the RTP stream - while (!isInterrupted()) - { - DecodeUnit du; - - try { - du = depacketizer.getNextDecodeUnit(); - } catch (InterruptedException e) { - listener.connectionTerminated(e); - return; - } - - decRend.submitDecodeUnit(du); - } - } - }; - threads.add(t); - t.setName("Video - Decoder"); - t.setPriority(Thread.MAX_PRIORITY); - t.start(); - } - private void startReceiveThread() { // Receive thread @@ -252,6 +214,7 @@ public class VideoStream { }; threads.add(t); t.setName("Video - Receive"); + t.setPriority(Thread.MAX_PRIORITY); t.start(); } diff --git a/moonlight-common/src/com/limelight/nvstream/http/LimelightCryptoProvider.java b/moonlight-common/src/com/limelight/nvstream/http/LimelightCryptoProvider.java new file mode 100644 index 00000000..fe313cd9 --- /dev/null +++ b/moonlight-common/src/com/limelight/nvstream/http/LimelightCryptoProvider.java @@ -0,0 +1,11 @@ +package com.limelight.nvstream.http; + +import java.security.cert.X509Certificate; +import java.security.interfaces.RSAPrivateKey; + +public interface LimelightCryptoProvider { + public X509Certificate getClientCertificate(); + public RSAPrivateKey getClientPrivateKey(); + public byte[] getPemEncodedClientCertificate(); + public String encodeBase64String(byte[] data); +} diff --git a/moonlight-common/src/com/limelight/nvstream/http/NvHTTP.java b/moonlight-common/src/com/limelight/nvstream/http/NvHTTP.java index 5438e5c6..765281c0 100644 --- a/moonlight-common/src/com/limelight/nvstream/http/NvHTTP.java +++ b/moonlight-common/src/com/limelight/nvstream/http/NvHTTP.java @@ -3,13 +3,19 @@ package com.limelight.nvstream.http; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.io.Reader; +import java.io.StringReader; import java.net.Inet6Address; import java.net.InetAddress; +import java.net.MalformedURLException; import java.net.URL; import java.net.URLConnection; import java.util.LinkedList; +import java.util.Scanner; import java.util.Stack; +import javax.crypto.SecretKey; + import org.xmlpull.v1.XmlPullParser; import org.xmlpull.v1.XmlPullParserException; import org.xmlpull.v1.XmlPullParserFactory; @@ -17,14 +23,19 @@ import org.xmlpull.v1.XmlPullParserFactory; public class NvHTTP { private String uniqueId; + private PairingManager pm; + private LimelightCryptoProvider cryptoProvider; - public static final int PORT = 47989; + public static final int PORT = 47984; public static final int CONNECTION_TIMEOUT = 5000; + + private final boolean verbose = false; public String baseUrl; - - public NvHTTP(InetAddress host, String uniqueId, String deviceName) { + + public NvHTTP(InetAddress host, String uniqueId, String deviceName, LimelightCryptoProvider cryptoProvider) { this.uniqueId = uniqueId; + this.cryptoProvider = cryptoProvider; String safeAddress; if (host instanceof Inet6Address) { @@ -35,16 +46,16 @@ public class NvHTTP { safeAddress = host.getHostAddress(); } - this.baseUrl = "http://" + safeAddress + ":" + PORT; + this.baseUrl = "https://" + safeAddress + ":" + PORT; + this.pm = new PairingManager(this, cryptoProvider); } - - private String getXmlString(InputStream in, String tagname) - throws XmlPullParserException, IOException { + + static String getXmlString(Reader r, String tagname) throws XmlPullParserException, IOException { XmlPullParserFactory factory = XmlPullParserFactory.newInstance(); factory.setNamespaceAware(true); XmlPullParser xpp = factory.newPullParser(); - xpp.setInput(new InputStreamReader(in)); + xpp.setInput(r); int eventType = xpp.getEventType(); Stack currentTag = new Stack(); @@ -70,8 +81,16 @@ public class NvHTTP { return null; } + + static String getXmlString(String str, String tagname) throws XmlPullParserException, IOException { + return getXmlString(new StringReader(str), tagname); + } - private void verifyResponseStatus(XmlPullParser xpp) throws GfeHttpResponseException { + static String getXmlString(InputStream in, String tagname) throws XmlPullParserException, IOException { + return getXmlString(new InputStreamReader(in), tagname); + } + + private static void verifyResponseStatus(XmlPullParser xpp) throws GfeHttpResponseException { int statusCode = Integer.parseInt(xpp.getAttributeValue(XmlPullParser.NO_NAMESPACE, "status_code")); if (statusCode != 200) { throw new GfeHttpResponseException(statusCode, xpp.getAttributeValue(XmlPullParser.NO_NAMESPACE, "status_message")); @@ -80,31 +99,43 @@ public class NvHTTP { private InputStream openHttpConnection(String url) throws IOException { URLConnection conn = new URL(url).openConnection(); + if (verbose) { + System.out.println(url); + } conn.setConnectTimeout(CONNECTION_TIMEOUT); - conn.setDefaultUseCaches(false); + conn.setUseCaches(false); conn.connect(); return conn.getInputStream(); } + + String openHttpConnectionToString(String url) throws MalformedURLException, IOException { + Scanner s = new Scanner(openHttpConnection(url)); + + String str = ""; + while (s.hasNext()) { + str += s.next() + " "; + } + + s.close(); + + if (verbose) { + System.out.println(str); + } + + return str; + } - public String getAppVersion() throws XmlPullParserException, IOException { - InputStream in = openHttpConnection(baseUrl + "/appversion"); + public String getServerVersion() throws XmlPullParserException, IOException { + InputStream in = openHttpConnection(baseUrl + "/serverinfo?uniqueid=" + uniqueId); return getXmlString(in, "appversion"); } - public boolean getPairState() throws IOException, XmlPullParserException { - InputStream in = openHttpConnection(baseUrl + "/pairstate?uniqueid=" + uniqueId); - String paired = getXmlString(in, "paired"); - return Integer.valueOf(paired) != 0; - } - - public int getSessionId() throws IOException, XmlPullParserException { - InputStream in = openHttpConnection(baseUrl + "/pair?uniqueid=" + uniqueId); - String sessionId = getXmlString(in, "sessionid"); - return Integer.parseInt(sessionId); + public PairingManager.PairState getPairState() throws IOException, XmlPullParserException { + return pm.getPairState(uniqueId); } public int getCurrentGame() throws IOException, XmlPullParserException { - InputStream in = openHttpConnection(baseUrl + "/serverinfo"); + InputStream in = openHttpConnection(baseUrl + "/serverinfo?uniqueid=" + uniqueId); String game = getXmlString(in, "currentgame"); return Integer.parseInt(game); } @@ -120,6 +151,10 @@ public class NvHTTP { return null; } + public PairingManager.PairState pair(String pin) throws Exception { + return pm.pair(uniqueId, pin); + } + public LinkedList getAppList() throws GfeHttpResponseException, IOException, XmlPullParserException { InputStream in = openHttpConnection(baseUrl + "/applist?uniqueid=" + uniqueId); XmlPullParserFactory factory = XmlPullParserFactory.newInstance(); @@ -161,18 +196,19 @@ public class NvHTTP { return appList; } - // Returns gameSession XML attribute - public int launchApp(int appId, int width, int height, int refreshRate) throws IOException, XmlPullParserException { + public int launchApp(int appId, int width, int height, int refreshRate, SecretKey inputKey) throws IOException, XmlPullParserException { InputStream in = openHttpConnection(baseUrl + "/launch?uniqueid=" + uniqueId + "&appid=" + appId + - "&mode=" + width + "x" + height + "x" + refreshRate); + "&mode=" + width + "x" + height + "x" + refreshRate + + "&additionalStates=1&sops=1&rikey="+cryptoProvider.encodeBase64String(inputKey.getEncoded())); String gameSession = getXmlString(in, "gamesession"); return Integer.parseInt(gameSession); } - public boolean resumeApp() throws IOException, XmlPullParserException { - InputStream in = openHttpConnection(baseUrl + "/resume?uniqueid=" + uniqueId); + public boolean resumeApp(SecretKey inputKey) throws IOException, XmlPullParserException { + InputStream in = openHttpConnection(baseUrl + "/resume?uniqueid=" + uniqueId + + "&rikey="+cryptoProvider.encodeBase64String(inputKey.getEncoded())); String resume = getXmlString(in, "resume"); return Integer.parseInt(resume) != 0; } diff --git a/moonlight-common/src/com/limelight/nvstream/http/PairingManager.java b/moonlight-common/src/com/limelight/nvstream/http/PairingManager.java new file mode 100644 index 00000000..60a2e8e3 --- /dev/null +++ b/moonlight-common/src/com/limelight/nvstream/http/PairingManager.java @@ -0,0 +1,326 @@ +package com.limelight.nvstream.http; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.SecretKeySpec; +import javax.net.ssl.*; + +import org.xmlpull.v1.XmlPullParserException; + +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.io.*; +import java.net.MalformedURLException; +import java.net.Socket; +import java.security.*; +import java.security.cert.*; +import java.util.Arrays; +import java.util.Random; + +public class PairingManager { + + private NvHTTP http; + + private PrivateKey pk; + private X509Certificate cert; + private SecretKey aesKey; + private byte[] pemCertBytes; + + public enum PairState { + NOT_PAIRED, + PAIRED, + PIN_WRONG, + FAILED + } + + public PairingManager(NvHTTP http, LimelightCryptoProvider cryptoProvider) { + this.http = http; + this.cert = cryptoProvider.getClientCertificate(); + this.pemCertBytes = cryptoProvider.getPemEncodedClientCertificate(); + this.pk = cryptoProvider.getClientPrivateKey(); + + // Update the trust manager and key manager to use our certificate and PK + installSslKeysAndTrust(); + } + + private void installSslKeysAndTrust() { + // Create a trust manager that does not validate certificate chains + TrustManager[] trustAllCerts = new TrustManager[] { + new X509TrustManager() { + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + public void checkClientTrusted(X509Certificate[] certs, String authType) {} + public void checkServerTrusted(X509Certificate[] certs, String authType) {} + }}; + + KeyManager[] ourKeyman = new KeyManager[] { + new X509KeyManager() { + public String chooseClientAlias(String[] keyTypes, + Principal[] issuers, Socket socket) { + return "Limelight-RSA"; + } + + public String chooseServerAlias(String keyType, Principal[] issuers, + Socket socket) { + return null; + } + + public X509Certificate[] getCertificateChain(String alias) { + return new X509Certificate[] {cert}; + } + + public String[] getClientAliases(String keyType, Principal[] issuers) { + return null; + } + + public PrivateKey getPrivateKey(String alias) { + return pk; + } + + public String[] getServerAliases(String keyType, Principal[] issuers) { + return null; + } + } + }; + + // Ignore differences between given hostname and certificate hostname + HostnameVerifier hv = new HostnameVerifier() { + public boolean verify(String hostname, SSLSession session) { return true; } + }; + + // Install the all-trusting trust manager + try { + SSLContext sc = SSLContext.getInstance("SSL"); + sc.init(ourKeyman, trustAllCerts, new SecureRandom()); + HttpsURLConnection.setDefaultSSLSocketFactory(sc.getSocketFactory()); + HttpsURLConnection.setDefaultHostnameVerifier(hv); + } catch (Exception e) { + e.printStackTrace(); + } + } + + final private static char[] hexArray = "0123456789ABCDEF".toCharArray(); + private static String bytesToHex(byte[] bytes) { + char[] hexChars = new char[bytes.length * 2]; + for ( int j = 0; j < bytes.length; j++ ) { + int v = bytes[j] & 0xFF; + hexChars[j * 2] = hexArray[v >>> 4]; + hexChars[j * 2 + 1] = hexArray[v & 0x0F]; + } + return new String(hexChars); + } + + private static byte[] hexToBytes(String s) { + int len = s.length(); + byte[] data = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + data[i / 2] = (byte) ((Character.digit(s.charAt(i), 16) << 4) + + Character.digit(s.charAt(i+1), 16)); + } + return data; + } + + private X509Certificate extractPlainCert(String text) throws XmlPullParserException, IOException, CertificateException + { + String certText = NvHTTP.getXmlString(text, "plaincert"); + byte[] certBytes = hexToBytes(certText); + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + return (X509Certificate)cf.generateCertificate(new ByteArrayInputStream(certBytes)); + } + + private byte[] generateRandomBytes(int length) + { + byte[] rand = new byte[length]; + new SecureRandom().nextBytes(rand); + return rand; + } + + private static byte[] saltPin(byte[] salt, String pin) throws UnsupportedEncodingException { + byte[] saltedPin = new byte[salt.length + pin.length()]; + System.arraycopy(salt, 0, saltedPin, 0, salt.length); + System.arraycopy(pin.getBytes("UTF-8"), 0, saltedPin, salt.length, pin.length()); + return saltedPin; + } + + private static byte[] toSHA1Bytes(byte[] data) { + try { + MessageDigest md = MessageDigest.getInstance("SHA-1"); + return md.digest(data); + } + catch (NoSuchAlgorithmException e) { + // Shouldn't ever happen + e.printStackTrace(); + return null; + } + } + + private static boolean verifySignature(byte[] data, byte[] signature, Certificate cert) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException { + Signature sig = Signature.getInstance("SHA256withRSA"); + sig.initVerify(cert.getPublicKey()); + sig.update(data); + return sig.verify(signature); + } + + private static byte[] signData(byte[] data, PrivateKey key) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException { + Signature sig = Signature.getInstance("SHA256withRSA"); + sig.initSign(key); + sig.update(data); + byte[] signature = new byte[256]; + sig.sign(signature, 0, signature.length); + return signature; + } + + private static byte[] decryptAes(byte[] encryptedData, SecretKey secretKey) throws NoSuchAlgorithmException, SignatureException, + InvalidKeyException, ShortBufferException, IllegalBlockSizeException, BadPaddingException, NoSuchPaddingException { + Cipher cipher = Cipher.getInstance("AES/ECB/NoPadding"); + + int blockRoundedSize = ((encryptedData.length + 15) / 16) * 16; + byte[] blockRoundedEncrypted = Arrays.copyOf(encryptedData, blockRoundedSize); + byte[] fullDecrypted = new byte[blockRoundedSize]; + + cipher.init(Cipher.DECRYPT_MODE, secretKey); + cipher.doFinal(blockRoundedEncrypted, 0, + blockRoundedSize, fullDecrypted); + return fullDecrypted; + } + + private static byte[] encryptAes(byte[] data, SecretKey secretKey) throws NoSuchAlgorithmException, SignatureException, + InvalidKeyException, ShortBufferException, IllegalBlockSizeException, BadPaddingException, NoSuchPaddingException { + Cipher cipher = Cipher.getInstance("AES/ECB/NoPadding"); + + int blockRoundedSize = ((data.length + 15) / 16) * 16; + byte[] blockRoundedData = Arrays.copyOf(data, blockRoundedSize); + + cipher.init(Cipher.ENCRYPT_MODE, secretKey); + return cipher.doFinal(blockRoundedData); + } + + private static SecretKey generateAesKey(byte[] keyData) { + byte[] aesTruncated = Arrays.copyOf(toSHA1Bytes(keyData), 16); + return new SecretKeySpec(aesTruncated, "AES"); + } + + private static byte[] concatBytes(byte[] a, byte[] b) { + byte[] c = new byte[a.length + b.length]; + System.arraycopy(a, 0, c, 0, a.length); + System.arraycopy(b, 0, c, a.length, b.length); + return c; + } + + public static String generatePinString() { + Random r = new Random(); + return String.format("%d%d%d%d", + r.nextInt(10), r.nextInt(10), + r.nextInt(10), r.nextInt(10)); + } + + public PairState getPairState(String uniqueId) throws MalformedURLException, IOException, XmlPullParserException { + String serverInfo = http.openHttpConnectionToString(http.baseUrl + "/serverinfo?uniqueid="+uniqueId); + if (!NvHTTP.getXmlString(serverInfo, "PairStatus").equals("1")) { + return PairState.NOT_PAIRED; + } + + String pairChallenge = http.openHttpConnectionToString(http.baseUrl + "/pair?uniqueid="+uniqueId+"&devicename=roth&updateState=1&phrase=pairchallenge"); + if (NvHTTP.getXmlString(pairChallenge, "paired").equals("1")) { + return PairState.PAIRED; + } + else { + return PairState.NOT_PAIRED; + } + } + + public PairState pair(String uniqueId, String pin) throws MalformedURLException, IOException, XmlPullParserException, CertificateException, InvalidKeyException, NoSuchAlgorithmException, SignatureException, ShortBufferException, IllegalBlockSizeException, BadPaddingException, NoSuchPaddingException { + // Generate a salt for hashing the PIN + byte[] salt = generateRandomBytes(16); + + // Combine the salt and pin, then create an AES key from them + byte[] saltAndPin = saltPin(salt, pin); + aesKey = generateAesKey(saltAndPin); + + // Send the salt and get the server cert + String getCert = http.openHttpConnectionToString(http.baseUrl + + "/pair?uniqueid="+uniqueId+"&devicename=roth&updateState=1&phrase=getservercert&salt="+bytesToHex(salt)+"&clientcert="+bytesToHex(pemCertBytes)); + if (!NvHTTP.getXmlString(getCert, "paired").equals("1")) { + http.openHttpConnectionToString(http.baseUrl + "/unpair?uniqueid="+uniqueId); + return PairState.FAILED; + } + X509Certificate serverCert = extractPlainCert(getCert); + + // Generate a random challenge and encrypt it with our AES key + byte[] randomChallenge = generateRandomBytes(16); + byte[] encryptedChallenge = encryptAes(randomChallenge, aesKey); + + // Send the encrypted challenge to the server + String challengeResp = http.openHttpConnectionToString(http.baseUrl + + "/pair?uniqueid="+uniqueId+"&devicename=roth&updateState=1&clientchallenge="+bytesToHex(encryptedChallenge)); + if (!NvHTTP.getXmlString(challengeResp, "paired").equals("1")) { + http.openHttpConnectionToString(http.baseUrl + "/unpair?uniqueid="+uniqueId); + return PairState.FAILED; + } + + // Decode the server's response and subsequent challenge + byte[] encServerChallengeResponse = hexToBytes(NvHTTP.getXmlString(challengeResp, "challengeresponse")); + byte[] decServerChallengeResponse = decryptAes(encServerChallengeResponse, aesKey); + + byte[] serverResponse = Arrays.copyOfRange(decServerChallengeResponse, 0, 20); + byte[] serverChallenge = Arrays.copyOfRange(decServerChallengeResponse, 20, 36); + + // Using another 16 bytes secret, compute a challenge response hash using the secret, our cert sig, and the challenge + byte[] clientSecret = generateRandomBytes(16); + byte[] challengeRespHash = toSHA1Bytes(concatBytes(concatBytes(serverChallenge, cert.getSignature()), clientSecret)); + byte[] challengeRespEncrypted = encryptAes(challengeRespHash, aesKey); + String secretResp = http.openHttpConnectionToString(http.baseUrl + + "/pair?uniqueid="+uniqueId+"&devicename=roth&updateState=1&serverchallengeresp="+bytesToHex(challengeRespEncrypted)); + if (!NvHTTP.getXmlString(secretResp, "paired").equals("1")) { + http.openHttpConnectionToString(http.baseUrl + "/unpair?uniqueid="+uniqueId); + return PairState.FAILED; + } + + // Get the server's signed secret + byte[] serverSecretResp = hexToBytes(NvHTTP.getXmlString(secretResp, "pairingsecret")); + byte[] serverSecret = Arrays.copyOfRange(serverSecretResp, 0, 16); + byte[] serverSignature = Arrays.copyOfRange(serverSecretResp, 16, 272); + + // Ensure the authenticity of the data + if (!verifySignature(serverSecret, serverSignature, serverCert)) { + // Cancel the pairing process + http.openHttpConnectionToString(http.baseUrl + "/unpair?uniqueid="+uniqueId); + + // Looks like a MITM + return PairState.FAILED; + } + + // Ensure the server challenge matched what we expected (aka the PIN was correct) + byte[] serverChallengeRespHash = toSHA1Bytes(concatBytes(concatBytes(randomChallenge, serverCert.getSignature()), serverSecret)); + if (!Arrays.equals(serverChallengeRespHash, serverResponse)) { + // Cancel the pairing process + http.openHttpConnectionToString(http.baseUrl + "/unpair?uniqueid="+uniqueId); + + // Probably got the wrong PIN + return PairState.PIN_WRONG; + } + + // Send the server our signed secret + byte[] clientPairingSecret = concatBytes(clientSecret, signData(clientSecret, pk)); + String clientSecretResp = http.openHttpConnectionToString(http.baseUrl + + "/pair?uniqueid="+uniqueId+"&devicename=roth&updateState=1&clientpairingsecret="+bytesToHex(clientPairingSecret)); + if (!NvHTTP.getXmlString(clientSecretResp, "paired").equals("1")) { + http.openHttpConnectionToString(http.baseUrl + "/unpair?uniqueid="+uniqueId); + return PairState.FAILED; + } + + // Do the initial challenge (seems neccessary for us to show as paired) + String pairChallenge = http.openHttpConnectionToString(http.baseUrl + "/pair?uniqueid="+uniqueId+"&devicename=roth&updateState=1&phrase=pairchallenge"); + if (!NvHTTP.getXmlString(pairChallenge, "paired").equals("1")) { + http.openHttpConnectionToString(http.baseUrl + "/unpair?uniqueid="+uniqueId); + return PairState.FAILED; + } + + return PairState.PAIRED; + } +} diff --git a/moonlight-common/src/com/limelight/nvstream/input/NvController.java b/moonlight-common/src/com/limelight/nvstream/input/NvController.java index 28655a6e..2848ce04 100644 --- a/moonlight-common/src/com/limelight/nvstream/input/NvController.java +++ b/moonlight-common/src/com/limelight/nvstream/input/NvController.java @@ -5,6 +5,15 @@ import java.io.OutputStream; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; + +import javax.crypto.Cipher; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.spec.IvParameterSpec; public class NvController { @@ -15,10 +24,27 @@ public class NvController { private InetAddress host; private Socket s; private OutputStream out; + private Cipher riCipher; + - public NvController(InetAddress host) + private final static byte[] ENCRYPTED_HEADER = new byte[] {0x00, 0x00, 0x00, 0x20}; + + public NvController(InetAddress host, SecretKey riKey) { this.host = host; + try { + // This cipher is guaranteed to be supported + this.riCipher = Cipher.getInstance("AES/CBC/NoPadding"); + this.riCipher.init(Cipher.ENCRYPT_MODE, riKey, new IvParameterSpec(new byte[16])); + } catch (NoSuchAlgorithmException e) { + e.printStackTrace(); + } catch (NoSuchPaddingException e) { + e.printStackTrace(); + } catch (InvalidKeyException e) { + e.printStackTrace(); + } catch (InvalidAlgorithmParameterException e) { + e.printStackTrace(); + } } public void initialize() throws IOException @@ -36,36 +62,51 @@ public class NvController { } catch (IOException e) {} } + private byte[] encryptAesInputData(byte[] data) throws Exception { + // Input data is rounded to units of 32 bytes + byte[] blockRoundedData = Arrays.copyOf(data, 32); + return riCipher.update(blockRoundedData); + } + + private void sendPacket(InputPacket packet) throws IOException { + out.write(ENCRYPTED_HEADER); + byte[] encryptedInput; + try { + encryptedInput = encryptAesInputData(packet.toWire()); + } catch (Exception e) { + // Should never happen + e.printStackTrace(); + return; + } + out.write(encryptedInput); + out.flush(); + } + public void sendControllerInput(short buttonFlags, byte leftTrigger, byte rightTrigger, short leftStickX, short leftStickY, short rightStickX, short rightStickY) throws IOException { - out.write(new ControllerPacket(buttonFlags, leftTrigger, + sendPacket(new ControllerPacket(buttonFlags, leftTrigger, rightTrigger, leftStickX, leftStickY, - rightStickX, rightStickY).toWire()); - out.flush(); + rightStickX, rightStickY)); } public void sendMouseButtonDown(byte mouseButton) throws IOException { - out.write(new MouseButtonPacket(true, mouseButton).toWire()); - out.flush(); + sendPacket(new MouseButtonPacket(true, mouseButton)); } public void sendMouseButtonUp(byte mouseButton) throws IOException { - out.write(new MouseButtonPacket(false, mouseButton).toWire()); - out.flush(); + sendPacket(new MouseButtonPacket(false, mouseButton)); } public void sendMouseMove(short deltaX, short deltaY) throws IOException { - out.write(new MouseMovePacket(deltaX, deltaY).toWire()); - out.flush(); + sendPacket(new MouseMovePacket(deltaX, deltaY)); } public void sendKeyboardInput(short keyMap, byte keyDirection, byte modifier) throws IOException { - out.write(new KeyboardPacket(keyMap, keyDirection, modifier).toWire()); - out.flush(); + sendPacket(new KeyboardPacket(keyMap, keyDirection, modifier)); } }