mirror of
https://github.com/rustdesk/hbb_common.git
synced 2025-07-03 00:17:17 +00:00
[fix bug] fix twice enc.
- fix twice enc. - enable proxy (temp ver.) - rewite websocket::next.
This commit is contained in:
parent
2d65c24e4b
commit
d299e4909f
@ -120,7 +120,7 @@ impl Stream {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Stream {
|
impl Stream {
|
||||||
/// 从 TCP 流创建(自动判断是否升级为 WebSocket)
|
/// establish connect from tcp.
|
||||||
pub async fn from_tcp(
|
pub async fn from_tcp(
|
||||||
stream: tokio::net::TcpStream,
|
stream: tokio::net::TcpStream,
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
@ -135,7 +135,7 @@ impl Stream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 创建 WebSocket 客户端连接
|
/// establish connect from websocket
|
||||||
pub async fn connect_websocket(
|
pub async fn connect_websocket(
|
||||||
url: impl AsRef<str>,
|
url: impl AsRef<str>,
|
||||||
local_addr: Option<SocketAddr>,
|
local_addr: Option<SocketAddr>,
|
||||||
@ -147,7 +147,7 @@ impl Stream {
|
|||||||
Ok(Self::WebSocket(ws_stream))
|
Ok(Self::WebSocket(ws_stream))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 发送消息
|
/// send message
|
||||||
pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> {
|
pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> {
|
||||||
match self {
|
match self {
|
||||||
Self::WebSocket(ws) => ws.send(msg).await,
|
Self::WebSocket(ws) => ws.send(msg).await,
|
||||||
@ -155,7 +155,7 @@ impl Stream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 接收消息
|
/// receive message
|
||||||
pub async fn next(&mut self) -> Option<Result<bytes::BytesMut, std::io::Error>> {
|
pub async fn next(&mut self) -> Option<Result<bytes::BytesMut, std::io::Error>> {
|
||||||
match self {
|
match self {
|
||||||
Self::WebSocket(ws) => ws.next().await,
|
Self::WebSocket(ws) => ws.next().await,
|
||||||
@ -163,7 +163,6 @@ impl Stream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 其他必要的方法...
|
|
||||||
pub fn local_addr(&self) -> SocketAddr {
|
pub fn local_addr(&self) -> SocketAddr {
|
||||||
match self {
|
match self {
|
||||||
Self::WebSocket(ws) => ws.local_addr(),
|
Self::WebSocket(ws) => ws.local_addr(),
|
||||||
|
@ -11,7 +11,6 @@ use std::{
|
|||||||
net::SocketAddr,
|
net::SocketAddr,
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream};
|
|
||||||
use tokio::{net::TcpStream, time::timeout};
|
use tokio::{net::TcpStream, time::timeout};
|
||||||
use tokio_tungstenite::{
|
use tokio_tungstenite::{
|
||||||
connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream,
|
connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream,
|
||||||
@ -26,26 +25,58 @@ pub struct WsFramedStream {
|
|||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
encrypt: Option<Encrypt>,
|
encrypt: Option<Encrypt>,
|
||||||
send_timeout: u64,
|
send_timeout: u64,
|
||||||
read_buf: BytesMut,
|
// read_buf: BytesMut,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WsFramedStream {
|
impl WsFramedStream {
|
||||||
pub async fn new<T: AsRef<str>>(
|
pub async fn new<T: AsRef<str>>(
|
||||||
url: T,
|
url: T,
|
||||||
local_addr: Option<SocketAddr>,
|
local_addr: Option<SocketAddr>,
|
||||||
_proxy_conf: Option<&Socks5Server>,
|
proxy_conf: Option<&Socks5Server>,
|
||||||
ms_timeout: u64,
|
ms_timeout: u64,
|
||||||
) -> ResultType<Self> {
|
) -> ResultType<Self> {
|
||||||
let (stream, _) = connect_async(url.as_ref()).await?;
|
let url_str = url.as_ref();
|
||||||
|
|
||||||
|
if let Some(proxy_conf) = proxy_conf {
|
||||||
|
// use proxy connect
|
||||||
|
let url_obj = url::Url::parse(url_str)?;
|
||||||
|
let host = url_obj
|
||||||
|
.host_str()
|
||||||
|
.ok_or_else(|| Error::new(ErrorKind::Other, "Invalid URL: no host"))?;
|
||||||
|
|
||||||
|
let port = url_obj
|
||||||
|
.port()
|
||||||
|
.unwrap_or(if url_obj.scheme() == "wss" { 443 } else { 80 });
|
||||||
|
|
||||||
|
let socket =
|
||||||
|
tokio_socks::tcp::Socks5Stream::connect(proxy_conf.proxy.as_str(), (host, port))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let tcp_stream = socket.into_inner();
|
||||||
|
let maybe_tls_stream = MaybeTlsStream::Plain(tcp_stream);
|
||||||
|
let ws_stream =
|
||||||
|
WebSocketStream::from_raw_socket(maybe_tls_stream, Role::Client, None).await;
|
||||||
|
|
||||||
|
let addr = match ws_stream.get_ref() {
|
||||||
|
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
|
||||||
|
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
stream: ws_stream,
|
||||||
|
addr,
|
||||||
|
encrypt: None,
|
||||||
|
send_timeout: ms_timeout,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
let (stream, _) = connect_async(url_str).await?;
|
||||||
|
|
||||||
// 获取底层TCP流的peer_addr
|
|
||||||
let addr = match stream.get_ref() {
|
let addr = match stream.get_ref() {
|
||||||
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
|
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
|
||||||
#[cfg(feature = "native-tls")]
|
#[cfg(feature = "native-tls")]
|
||||||
MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?,
|
MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?,
|
||||||
#[cfg(feature = "rustls")]
|
#[cfg(feature = "rustls")]
|
||||||
MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?,
|
MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?,
|
||||||
// 处理其他可能的情况
|
|
||||||
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
|
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -54,14 +85,12 @@ impl WsFramedStream {
|
|||||||
addr,
|
addr,
|
||||||
encrypt: None,
|
encrypt: None,
|
||||||
send_timeout: ms_timeout,
|
send_timeout: ms_timeout,
|
||||||
read_buf: BytesMut::new(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_raw(&mut self) {
|
|
||||||
// WebSocket不需要特殊处理,保持空实现
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_raw(&mut self) {}
|
||||||
|
|
||||||
pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType<Self> {
|
pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType<Self> {
|
||||||
let ws_stream =
|
let ws_stream =
|
||||||
WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Server, None)
|
WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Server, None)
|
||||||
@ -72,16 +101,13 @@ impl WsFramedStream {
|
|||||||
addr,
|
addr,
|
||||||
encrypt: None,
|
encrypt: None,
|
||||||
send_timeout: 0,
|
send_timeout: 0,
|
||||||
read_buf: BytesMut::new(),
|
// read_buf: BytesMut::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn from(stream: TcpStream, addr: SocketAddr) -> Self {
|
pub async fn from(stream: TcpStream, addr: SocketAddr) -> Self {
|
||||||
let ws_stream = WebSocketStream::from_raw_socket(
|
let ws_stream =
|
||||||
MaybeTlsStream::Plain(stream), // 包装为MaybeTlsStream
|
WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None)
|
||||||
Role::Client,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
@ -89,7 +115,7 @@ impl WsFramedStream {
|
|||||||
addr,
|
addr,
|
||||||
encrypt: None,
|
encrypt: None,
|
||||||
send_timeout: 0,
|
send_timeout: 0,
|
||||||
read_buf: BytesMut::new(),
|
// read_buf: BytesMut::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,14 +151,7 @@ impl WsFramedStream {
|
|||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
|
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
|
||||||
// 转换为Vec<u8>时需要处理加密
|
let msg = WsMessage::Binary(Bytes::from(bytes));
|
||||||
let data = if let Some(key) = self.encrypt.as_mut() {
|
|
||||||
key.enc(&bytes.to_vec())
|
|
||||||
} else {
|
|
||||||
bytes.to_vec()
|
|
||||||
};
|
|
||||||
|
|
||||||
let msg = WsMessage::Binary(Bytes::from(data));
|
|
||||||
if self.send_timeout > 0 {
|
if self.send_timeout > 0 {
|
||||||
let send_future = self.stream.send(msg);
|
let send_future = self.stream.send(msg);
|
||||||
timeout(Duration::from_millis(self.send_timeout), send_future)
|
timeout(Duration::from_millis(self.send_timeout), send_future)
|
||||||
@ -151,38 +170,22 @@ impl WsFramedStream {
|
|||||||
#[inline]
|
#[inline]
|
||||||
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
|
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
|
||||||
loop {
|
loop {
|
||||||
if let Some((frame, _)) = self.read_buf.split_first() {
|
|
||||||
if let Some(decrypted) = self.try_decrypt() {
|
|
||||||
return Some(Ok(decrypted));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
match self.stream.next().await? {
|
match self.stream.next().await? {
|
||||||
Ok(WsMessage::Binary(data)) => {
|
Ok(WsMessage::Binary(data)) => {
|
||||||
self.read_buf.extend_from_slice(&data);
|
let mut bytes = BytesMut::from(&data[..]);
|
||||||
if let Some(decrypted) = self.try_decrypt() {
|
if let Some(key) = self.encrypt.as_mut() {
|
||||||
return Some(Ok(decrypted));
|
if let Err(e) = key.dec(&mut bytes) {
|
||||||
|
return Some(Err(e));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(_) => continue, // 忽略非二进制消息
|
return Some(Ok(bytes));
|
||||||
|
}
|
||||||
|
Ok(_) => continue,
|
||||||
Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))),
|
Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_decrypt(&mut self) -> Option<BytesMut> {
|
|
||||||
if let Some(key) = self.encrypt.as_mut() {
|
|
||||||
if let Ok(()) = key.dec(&mut self.read_buf) {
|
|
||||||
let data = self.read_buf.split();
|
|
||||||
return Some(data);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let data = self.read_buf.split();
|
|
||||||
return Some(data);
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
|
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
|
||||||
match timeout(Duration::from_millis(ms), self.next()).await {
|
match timeout(Duration::from_millis(ms), self.next()).await {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user