[enhance] add ping pong send.

This commit is contained in:
YinMo19 2025-04-23 10:42:34 +08:00
parent 0d6948c97b
commit f31d1ec1b8

View File

@ -5,12 +5,15 @@ use crate::{
ResultType, ResultType,
}; };
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use std::sync::Arc;
use std::{ use std::{
io::{Error, ErrorKind}, io::{Error, ErrorKind},
net::SocketAddr, net::SocketAddr,
time::Duration, time::Duration,
}; };
use tokio::sync::Mutex;
use tokio::{net::TcpStream, time::timeout}; use tokio::{net::TcpStream, time::timeout};
use tokio_tungstenite::{ use tokio_tungstenite::{
connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream,
@ -22,7 +25,9 @@ use tungstenite::protocol::Role;
pub struct Encrypt(Key, u64, u64); pub struct Encrypt(Key, u64, u64);
pub struct WsFramedStream { pub struct WsFramedStream {
stream: WebSocketStream<MaybeTlsStream<TcpStream>>, // stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
writer: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, WsMessage>>>,
reader: futures::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
addr: SocketAddr, addr: SocketAddr,
encrypt: Option<Encrypt>, encrypt: Option<Encrypt>,
send_timeout: u64, send_timeout: u64,
@ -30,6 +35,9 @@ pub struct WsFramedStream {
} }
impl WsFramedStream { impl WsFramedStream {
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(15);
pub async fn new<T: AsRef<str>>( pub async fn new<T: AsRef<str>>(
url: T, url: T,
local_addr: Option<SocketAddr>, local_addr: Option<SocketAddr>,
@ -63,16 +71,22 @@ impl WsFramedStream {
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
}; };
Ok(Self { let (writer, reader) = ws_stream.split();
stream: ws_stream,
let mut ws = Self {
writer: Arc::new(Mutex::new(writer)),
reader,
addr, addr,
encrypt: None, encrypt: None,
send_timeout: ms_timeout, send_timeout: ms_timeout,
}) };
ws.start_heartbeat();
Ok(ws)
} else { } else {
log::info!("{:?}", url_str); log::info!("{:?}", url_str);
let mut request = url_str let request = url_str
.into_client_request() .into_client_request()
.map_err(|e| Error::new(ErrorKind::Other, e))?; .map_err(|e| Error::new(ErrorKind::Other, e))?;
@ -90,15 +104,36 @@ impl WsFramedStream {
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
}; };
Ok(Self { let (writer, reader) = stream.split();
stream, let mut ws = Self {
writer: Arc::new(Mutex::new(writer)),
reader,
addr, addr,
encrypt: None, encrypt: None,
send_timeout: ms_timeout, send_timeout: ms_timeout,
}) };
ws.start_heartbeat();
Ok(ws)
} }
} }
fn start_heartbeat(&self) {
let writer = Arc::clone(&self.writer);
tokio::spawn(async move {
let mut interval = tokio::time::interval(Self::HEARTBEAT_INTERVAL);
loop {
interval.tick().await;
let mut lock = writer.lock().await;
if let Err(e) = lock.send(WsMessage::Ping(Bytes::new())).await {
log::error!("Failed to send ping: {}", e);
break;
}
drop(lock); // 及时释放锁
}
});
}
pub fn set_raw(&mut self) {} pub fn set_raw(&mut self) {}
pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType<Self> { pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType<Self> {
@ -106,12 +141,14 @@ impl WsFramedStream {
WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None)
.await; .await;
let (writer, reader) = ws_stream.split();
Ok(Self { Ok(Self {
stream: ws_stream, writer: Arc::new(Mutex::new(writer)),
reader,
addr, addr,
encrypt: None, encrypt: None,
send_timeout: 0, send_timeout: 0,
// read_buf: BytesMut::new(),
}) })
} }
@ -120,12 +157,14 @@ impl WsFramedStream {
WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None)
.await; .await;
let (writer, reader) = ws_stream.split();
Self { Self {
stream: ws_stream, writer: Arc::new(Mutex::new(writer)),
reader,
addr, addr,
encrypt: None, encrypt: None,
send_timeout: 0, send_timeout: 0,
// read_buf: BytesMut::new(),
} }
} }
@ -157,52 +196,84 @@ impl WsFramedStream {
#[inline] #[inline]
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
let msg = WsMessage::Binary(Bytes::from(bytes)); let msg = WsMessage::Binary(bytes);
let mut writer = self.writer.lock().await;
if self.send_timeout > 0 { if self.send_timeout > 0 {
let send_future = self.stream.send(msg); timeout(Duration::from_millis(self.send_timeout), writer.send(msg)).await??
timeout(Duration::from_millis(self.send_timeout), send_future)
.await
.map_err(|_| Error::new(ErrorKind::TimedOut, "Send timeout"))?
.map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?;
} else { } else {
self.stream writer.send(msg).await?
.send(msg) };
.await
.map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?;
}
Ok(()) Ok(())
} }
#[inline] #[inline]
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> { pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
log::info!("test"); log::debug!("Waiting for next message");
let start = std::time::Instant::now();
loop { loop {
match self.stream.next().await? { match self.reader.next().await {
Ok(WsMessage::Binary(data)) => { Some(Ok(msg)) => {
let mut bytes = BytesMut::from(&data[..]); log::debug!("Received message: {:?}", &msg);
if let Some(key) = self.encrypt.as_mut() { match msg {
if let Err(e) = key.dec(&mut bytes) { WsMessage::Binary(data) => {
return Some(Err(e)); log::info!("Received binary data ({} bytes)", data.len());
let mut bytes = BytesMut::from(&data[..]);
if let Some(key) = self.encrypt.as_mut() {
log::debug!("Decrypting data with seq: {}", key.2);
match key.dec(&mut bytes) {
Ok(_) => {
log::debug!("Decryption successful");
return Some(Ok(bytes));
}
Err(e) => {
log::error!("Decryption failed: {}", e);
return Some(Err(e));
}
}
}
return Some(Ok(bytes));
}
WsMessage::Ping(ping) => {
log::info!("Received ping ({} bytes)", ping.len());
let mut writer = self.writer.lock().await;
if let Err(e) = writer.send(WsMessage::Pong(ping)).await {
log::error!("Failed to send pong: {}", e);
return Some(Err(Error::new(
ErrorKind::Other,
format!("Failed to send pong: {}", e),
)));
}
log::debug!("Pong sent");
}
WsMessage::Pong(_) => {
log::debug!("Received pong");
}
WsMessage::Close(frame) => {
log::info!("Connection closed: {:?}", frame);
return None;
}
_ => {
log::warn!("Unhandled message :{}", &msg);
} }
} }
return Some(Ok(bytes));
} }
Ok(WsMessage::Ping(ping)) => { Some(Err(e)) => {
if let Err(e) = self.stream.send(WsMessage::Pong(ping)).await { log::error!("WebSocket error: {}", e);
return Some(Err(Error::new( return Some(Err(Error::new(
ErrorKind::Other, ErrorKind::Other,
format!("Failed to send pong: {}", e), format!("Failed to send pong: {}", e),
))); )));
}
continue;
} }
Ok(WsMessage::Pong(_)) => { None => {
log::debug!("Received pong"); log::info!("Connection closed gracefully");
continue; return None;
} }
Ok(WsMessage::Close(_)) => return None, }
Ok(_) => continue,
Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))), if start.elapsed() > Self::HEARTBEAT_TIMEOUT {
log::warn!("No message received within heartbeat timeout");
return Some(Err(Error::new(ErrorKind::TimedOut, "Heartbeat timeout")));
} }
} }
} }