合并更新网络库
智能大石头 authored at 2024-02-02 16:28:58 智能大石头 committed at 2024-02-02 16:57:26
7.52 KiB
X
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Text;
using NewLife.Reflection;

namespace NewLife.Net;

/// <summary>Extension methods for <see cref="Socket"/>: send, broadcast and receive helpers plus graceful shutdown.</summary>
public static class SocketHelper
{
    /// <summary>Asynchronously sends an entire buffer on the socket.</summary>
    /// <remarks>
    /// Wraps the APM pair <c>BeginSend</c>/<c>EndSend</c> in a task via <c>TaskFactory&lt;Int32&gt;.FromAsync</c>.
    /// The whole array (offset 0, full length) is handed to <c>BeginSend</c> with <see cref="SocketFlags.None"/>.
    /// </remarks>
    /// <param name="socket">Socket to send on.</param>
    /// <param name="buffer">Data to send.</param>
    /// <returns>A task whose result is the byte count returned by <c>EndSend</c>.</returns>
    public static Task<Int32> SendAsync(this Socket socket, Byte[] buffer)
    {
        var task = Task<Int32>.Factory.FromAsync((Byte[] buf, AsyncCallback callback, Object? state) =>
        {
            return socket.BeginSend(buf, 0, buf.Length, SocketFlags.None, callback, state);
        }, socket.EndSend, buffer, null);

        return task;
    }

    /// <summary>Asynchronously sends an entire buffer to the given endpoint.</summary>
    /// <remarks>
    /// Wraps the APM pair <c>BeginSendTo</c>/<c>EndSendTo</c> in a task via <c>TaskFactory&lt;Int32&gt;.FromAsync</c>.
    /// The whole array (offset 0, full length) is sent with <see cref="SocketFlags.None"/>.
    /// </remarks>
    /// <param name="socket">Socket to send on.</param>
    /// <param name="buffer">Data to send.</param>
    /// <param name="remote">Destination endpoint.</param>
    /// <returns>A task whose result is the byte count returned by <c>EndSendTo</c>.</returns>
    public static Task<Int32> SendToAsync(this Socket socket, Byte[] buffer, IPEndPoint remote)
    {
        var task = Task<Int32>.Factory.FromAsync((Byte[] buf, IPEndPoint ep, AsyncCallback callback, Object? state) =>
        {
            return socket.BeginSendTo(buf, 0, buf.Length, SocketFlags.None, ep, callback, state);
        }, socket.EndSendTo, buffer, remote, null);

        return task;
    }

    /// <summary>Sends the remaining content of a stream, chunk by chunk, to a destination.</summary>
    /// <remarks>
    /// The stream is read from its current position in chunks of at most 1472 bytes. That size was presumably chosen
    /// so each chunk fits a single UDP datagram on a 1500-byte Ethernet MTU (1500 - 20 IPv4 header - 8 UDP header).
    /// Each chunk is sent with <c>SendTo</c>. Reading continues until <c>Stream.Read</c> returns 0.
    /// </remarks>
    /// <param name="socket">Socket to send on.</param>
    /// <param name="stream">Source stream.</param>
    /// <param name="remoteEP">Destination; when null, the socket's <see cref="Socket.RemoteEndPoint"/> is used.</param>
    /// <returns>The socket itself, for chaining.</returns>
    /// <exception cref="ArgumentNullException">No destination is given and the socket has no IP remote endpoint.</exception>
    public static Socket Send(this Socket socket, Stream stream, IPEndPoint? remoteEP = null)
    {
        remoteEP ??= socket.RemoteEndPoint as IPEndPoint;
        if (remoteEP == null) throw new ArgumentNullException(nameof(remoteEP));

        var buffer = new Byte[1472];
        while (true)
        {
            // Stream.Read may return fewer bytes than requested before the end of the stream
            // (e.g. network or pipe streams). Only a return value of 0 means the stream is exhausted,
            // so a short read must not terminate the loop.
            var n = stream.Read(buffer, 0, buffer.Length);
            if (n <= 0) break;

            socket.SendTo(buffer, 0, n, SocketFlags.None, remoteEP);
        }

        return socket;
    }

    /// <summary>Sends a buffer to a destination.</summary>
    /// <param name="socket">Socket to send on.</param>
    /// <param name="buffer">Data to send; the whole array is passed to <c>SendTo</c>.</param>
    /// <param name="remoteEP">Destination; when null, the socket's <see cref="Socket.RemoteEndPoint"/> is used.</param>
    /// <returns>The socket itself, for chaining.</returns>
    /// <exception cref="ArgumentNullException">No destination is given and the socket has no IP remote endpoint.</exception>
    public static Socket Send(this Socket socket, Byte[] buffer, IPEndPoint? remoteEP = null)
    {
        remoteEP ??= socket.RemoteEndPoint as IPEndPoint;
        if (remoteEP == null) throw new ArgumentNullException(nameof(remoteEP));

        socket.SendTo(buffer, 0, buffer.Length, SocketFlags.None, remoteEP);

        return socket;
    }

    /// <summary>Encodes a string and sends it to a destination.</summary>
    /// <param name="socket">Socket to send on.</param>
    /// <param name="message">Text to send.</param>
    /// <param name="encoding">Text encoding; null means UTF-8.</param>
    /// <param name="remoteEP">Destination; when null, the socket's <see cref="Socket.RemoteEndPoint"/> is used.</param>
    /// <returns>The socket itself, for chaining.</returns>
    /// <exception cref="ArgumentNullException">No destination is given and the socket has no IP remote endpoint.</exception>
    public static Socket Send(this Socket socket, String message, Encoding? encoding = null, IPEndPoint? remoteEP = null)
        => Send(socket, (encoding ?? Encoding.UTF8).GetBytes(message), remoteEP);

    /// <summary>Broadcasts a buffer to <see cref="IPAddress.Broadcast"/> on the given port.</summary>
    /// <remarks>Turns on <see cref="Socket.EnableBroadcast"/> if it is not already enabled.</remarks>
    /// <param name="socket">Socket to send on.</param>
    /// <param name="buffer">Data to broadcast.</param>
    /// <param name="port">Destination port.</param>
    /// <returns>The socket itself, for chaining.</returns>
    /// <exception cref="NotSupportedException">The socket is bound to a non-IPv4 local address.</exception>
    public static Socket Broadcast(this Socket socket, Byte[] buffer, Int32 port)
    {
        // Broadcast is only meaningful for IPv4; reject sockets already bound to another address family.
        if (socket.LocalEndPoint is IPEndPoint local && !local.Address.IsIPv4())
            throw new NotSupportedException("IPv6 does not support broadcasting!");

        if (!socket.EnableBroadcast) socket.EnableBroadcast = true;

        socket.SendTo(buffer, 0, buffer.Length, SocketFlags.None, new IPEndPoint(IPAddress.Broadcast, port));

        return socket;
    }

    /// <summary>Broadcasts a UTF-8 encoded string on the given port.</summary>
    /// <param name="socket">Socket to send on.</param>
    /// <param name="message">Text to broadcast.</param>
    /// <param name="port">Destination port.</param>
    /// <returns>The socket itself, for chaining.</returns>
    public static Socket Broadcast(this Socket socket, String message, Int32 port)
    {
        var buffer = Encoding.UTF8.GetBytes(message);
        return Broadcast(socket, buffer, port);
    }

    /// <summary>Receives one datagram (up to 1460 bytes) and decodes it as text.</summary>
    /// <param name="socket">Socket to receive from.</param>
    /// <param name="encoding">Text encoding; null means UTF-8.</param>
    /// <returns>The decoded text, or <see cref="String.Empty"/> when nothing was received.</returns>
    public static String ReceiveString(this Socket socket, Encoding? encoding = null)
    {
        // ReceiveFrom rejects a remoteEP whose address family differs from the socket's (unless the socket is
        // dual-mode), so seed it with the wildcard address that matches the socket's family.
        var any = socket.AddressFamily == AddressFamily.InterNetworkV6 ? IPAddress.IPv6Any : IPAddress.Any;
        EndPoint ep = new IPEndPoint(any, 0);

        var buf = new Byte[1460];
        var len = socket.ReceiveFrom(buf, ref ep);
        if (len < 1) return String.Empty;

        return (encoding ?? Encoding.UTF8).GetString(buf, 0, len);
    }

    /// <summary>Enables broadcast on the socket when the address looks like an IPv4 broadcast address.</summary>
    /// <remarks>
    /// The check is a heuristic: a 4-byte address whose last octet is 255.
    /// NOTE(review): the subnet mask is not considered, so this may miss or misjudge non-/24 broadcast addresses.
    /// </remarks>
    /// <param name="socket">Socket to configure.</param>
    /// <param name="address">Destination address being used.</param>
    /// <returns>The socket itself, for chaining.</returns>
    internal static Socket CheckBroadcast(this Socket socket, IPAddress address)
    {
        var buf = address.GetAddressBytes();
        if (buf.Length == 4 && buf[3] == 255)
        {
            if (!socket.EnableBroadcast) socket.EnableBroadcast = true;
        }

        return socket;
    }

    #region 关闭连接
    /// <summary>Gracefully closes the socket.</summary>
    /// <remarks>
    /// The close sequence is: <c>Shutdown(Both)</c> so pending outgoing data is sent, then <c>Disconnect</c>, then <c>Close</c>.
    /// <c>Shutdown</c> and <c>Disconnect</c> are attempted only while the socket reports <see cref="Socket.Connected"/>.
    /// Both are best-effort. Sockets whose underlying handle is already closed are skipped.
    /// </remarks>
    /// <param name="socket">Socket to close; null is ignored.</param>
    /// <param name="reuseAddress">Passed through to <see cref="Socket.Disconnect(Boolean)"/>.</param>
    public static void Shutdown(this Socket socket, Boolean reuseAddress = false)
    {
        if (socket == null) return;

        // Skip only when the handle is known to be closed. If the handle member could not be resolved via
        // reflection, still proceed: Close()/Dispose tolerates repeated calls, whereas returning early would
        // leave an open socket behind.
        var member = mSafeHandle;
        if (member != null && socket.GetValue(member) is SafeHandle { IsClosed: true }) return;

        if (socket.Connected)
        {
            // Shutdown must come before Disconnect. Disconnect marks the socket as no longer connected, after
            // which Shutdown fails. Each step is isolated so a failure in one (e.g. Disconnect unsupported for
            // the protocol/platform) does not skip the other. Failures are deliberately ignored because the
            // socket is closed below regardless.
            try { socket.Shutdown(SocketShutdown.Both); } catch (Exception) { }
            try { socket.Disconnect(reuseAddress); } catch (Exception) { }
        }

        socket.Close();
    }

    // Cache for the reflected handle member. The value is wrapped in an array so that a failed lookup (null)
    // is cached too: a non-empty array means "already looked up".
    private static MemberInfo?[]? _mSafeHandle;
    /// <summary>The <see cref="Socket"/> member holding its handle: field <c>m_Handle</c> if present, otherwise property <c>SafeHandle</c>; null if neither is found.</summary>
    private static MemberInfo? mSafeHandle
    {
        get
        {
            if (_mSafeHandle != null && _mSafeHandle.Length > 0) return _mSafeHandle[0];

            MemberInfo? pi = typeof(Socket).GetFieldEx("m_Handle");
            pi ??= typeof(Socket).GetPropertyEx("SafeHandle");
            _mSafeHandle = [pi];

            return pi;
        }
    }
    #endregion

    #region 异步事件
    /// <summary>Reports whether the async operation ended with OperationAborted, Interrupted or NotSocket.</summary>
    /// <remarks>
    /// NOTE(review): the original summary reads "socket is not closed", but these error codes commonly accompany
    /// a socket that was closed or aborted while the operation was pending. The name may be inverted relative to
    /// the result. Behavior is kept unchanged because callers depend on it; verify the call sites.
    /// </remarks>
    /// <param name="se">Completed async event args.</param>
    /// <returns>True when <see cref="SocketAsyncEventArgs.SocketError"/> is one of the three listed codes.</returns>
    internal static Boolean IsNotClosed(this SocketAsyncEventArgs se) => se.SocketError is SocketError.OperationAborted or SocketError.Interrupted or SocketError.NotSocket;

    /// <summary>Builds a reportable exception from a completed async operation, suppressing common benign errors.</summary>
    /// <param name="se">Completed async event args; may be null.</param>
    /// <returns>
    /// Null when <paramref name="se"/> is null, when the operation succeeded, or when the error is ConnectionReset,
    /// OperationAborted, Interrupted or NotSocket. Otherwise a <see cref="SocketException"/> carrying the error code.
    /// </returns>
    internal static Exception? GetException(this SocketAsyncEventArgs se)
    {
        if (se == null) return null;

        // Success is not an error; building a SocketException for it would only yield a meaningless exception.
        if (se.SocketError is SocketError.Success or
            SocketError.ConnectionReset or
            SocketError.OperationAborted or
            SocketError.Interrupted or
            SocketError.NotSocket)
            return null;

        return new SocketException((Int32)se.SocketError);
    }
    #endregion
}