diff --git a/mod.rs b/mod.rs index bcceddb..55ffc46 100644 --- a/mod.rs +++ b/mod.rs @@ -29,12 +29,14 @@ pub fn start() { 0, Default::default(), Default::default(), + "", STOP.clone(), )); }); std::thread::spawn(|| { allow_err!(relay_server::start( relay_server::DEFAULT_PORT, + "", STOP.clone() )); }); diff --git a/src/hbbr.rs b/src/hbbr.rs index 4225f91..e65b6ce 100644 --- a/src/hbbr.rs +++ b/src/hbbr.rs @@ -17,6 +17,6 @@ fn main() -> ResultType<()> { .args_from_usage(&args) .get_matches(); let stop: Arc> = Default::default(); - start(matches.value_of("port").unwrap_or(DEFAULT_PORT), stop)?; + start(matches.value_of("port").unwrap_or(DEFAULT_PORT), "", stop)?; Ok(()) } diff --git a/src/main.rs b/src/main.rs index 8978a55..ba69e01 100644 --- a/src/main.rs +++ b/src/main.rs @@ -83,6 +83,7 @@ fn main() -> ResultType<()> { serial, rendezvous_servers, get_arg("software-url", ""), + "", stop, )?; Ok(()) diff --git a/src/relay_server.rs b/src/relay_server.rs index a95f16c..decc128 100644 --- a/src/relay_server.rs +++ b/src/relay_server.rs @@ -23,7 +23,7 @@ lazy_static::lazy_static! { pub const DEFAULT_PORT: &'static str = "21117"; #[tokio::main(basic_scheduler)] -pub async fn start(port: &str, stop: Arc>) -> ResultType<()> { +pub async fn start(port: &str, license: &str, stop: Arc>) -> ResultType<()> { let addr = format!("0.0.0.0:{}", port); log::info!("Listening on {}", addr); let mut timer = interval(Duration::from_millis(300)); @@ -31,8 +31,9 @@ pub async fn start(port: &str, stop: Arc>) -> ResultType<()> { loop { tokio::select! { Ok((stream, addr)) = listener.accept() => { + let license = license.to_owned(); tokio::spawn(async move { - make_pair(FramedStream::from(stream), addr).await.ok(); + make_pair(FramedStream::from(stream), addr, &license).await.ok(); }); } _ = timer.tick() => { @@ -46,11 +47,14 @@ pub async fn start(port: &str, stop: Arc>) -> ResultType<()> { Ok(()) } -async fn make_pair(stream: FramedStream, addr: SocketAddr) -> ResultType<()> { +async fn make_pair(stream: FramedStream, addr: SocketAddr, license: &str) -> ResultType<()> { let mut stream = stream; if let Some(Ok(bytes)) = stream.next_timeout(30_000).await { if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { if let Some(rendezvous_message::Union::request_relay(rf)) = msg_in.union { + if !license.is_empty() && rf.licence_key != license { + return Ok(()); + } if !rf.uuid.is_empty() { let peer = PEERS.lock().unwrap().remove(&rf.uuid); if let Some(peer) = peer { diff --git a/src/rendezvous_server.rs b/src/rendezvous_server.rs index 80452d4..b766d17 100644 --- a/src/rendezvous_server.rs +++ b/src/rendezvous_server.rs @@ -124,7 +124,6 @@ impl PeerMap { } const REG_TIMEOUT: i32 = 30_000; -pub const LICENSE_KEY: &'static str = ""; type Sink = SplitSink, Bytes>; type Sender = mpsc::UnboundedSender<(RendezvousMessage, SocketAddr)>; static mut ROTATION_RELAY_SERVER: usize = 0; @@ -150,6 +149,7 @@ impl RendezvousServer { serial: i32, rendezvous_servers: Vec, software_url: String, + license: &str, stop: Arc>, ) -> ResultType<()> { let mut socket = FramedSocket::new(addr).await?; @@ -183,7 +183,7 @@ impl RendezvousServer { allow_err!(socket.send(&msg, addr).await); } Some(Ok((bytes, addr))) = socket.next() => { - allow_err!(rs.handle_msg(&bytes, addr, &mut socket).await); + allow_err!(rs.handle_msg(&bytes, addr, &mut socket, license).await); } Ok((stream, addr)) = listener2.accept() => { let stream = FramedStream::from(stream); @@ -208,6 +208,7 @@ impl RendezvousServer { let (a, mut b) = Framed::new(stream, BytesCodec::new()).split(); let tcp_punch = rs.tcp_punch.clone(); let mut rs = rs.clone(); + let license = license.to_owned(); tokio::spawn(async move { let mut sender = Some(a); while let Ok(Some(Ok(bytes))) = timeout(30_000, b.next()).await { @@ -218,7 +219,7 @@ impl RendezvousServer { if let Some(sender) = sender.take() { tcp_punch.lock().unwrap().insert(addr, sender); } - allow_err!(rs.handle_tcp_punch_hole_request(addr, ph).await); + allow_err!(rs.handle_tcp_punch_hole_request(addr, ph, &license).await); } Some(rendezvous_message::Union::request_relay(mut rf)) => { // there maybe several attempt, so sender can be none @@ -299,6 +300,7 @@ impl RendezvousServer { bytes: &BytesMut, addr: SocketAddr, socket: &mut FramedSocket, + license: &str, ) -> ResultType<()> { if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { match msg_in.union { @@ -348,12 +350,14 @@ impl RendezvousServer { } Some(rendezvous_message::Union::punch_hole_request(ph)) => { if self.pm.is_in_memory(&ph.id) { - self.handle_udp_punch_hole_request(addr, ph).await?; + self.handle_udp_punch_hole_request(addr, ph, license) + .await?; } else { // not in memory, fetch from db with spawn in case blocking me let mut me = self.clone(); + let license = license.to_owned(); tokio::spawn(async move { - allow_err!(me.handle_udp_punch_hole_request(addr, ph).await); + allow_err!(me.handle_udp_punch_hole_request(addr, ph, &license).await); }); } } @@ -526,8 +530,9 @@ impl RendezvousServer { &mut self, addr: SocketAddr, ph: PunchHoleRequest, + license: &str, ) -> ResultType<(RendezvousMessage, Option)> { - if ph.licence_key != LICENSE_KEY { + if !license.is_empty() && ph.licence_key != license { let mut msg_out = RendezvousMessage::new(); msg_out.set_punch_hole_response(PunchHoleResponse { failure: punch_hole_response::Failure::LICENCE_MISMATCH.into(), @@ -639,8 +644,9 @@ impl RendezvousServer { &mut self, addr: SocketAddr, ph: PunchHoleRequest, + license: &str, ) -> ResultType<()> { - let (msg, to_addr) = self.handle_punch_hole_request(addr, ph).await?; + let (msg, to_addr) = self.handle_punch_hole_request(addr, ph, license).await?; if let Some(addr) = to_addr { self.tx.send((msg, addr))?; } else { @@ -654,8 +660,9 @@ impl RendezvousServer { &mut self, addr: SocketAddr, ph: PunchHoleRequest, + license: &str, ) -> ResultType<()> { - let (msg, to_addr) = self.handle_punch_hole_request(addr, ph).await?; + let (msg, to_addr) = self.handle_punch_hole_request(addr, ph, license).await?; self.tx.send(( msg, match to_addr {