TFTP协议实现

TFTP协议简介

TFTP(Trivial File Transfer Protocol,简单文件传输协议)是 TCP/IP 协议族中的一个用来在客户机与服务器之间进行简单文件传输的协议,提供不复杂、开销不大的文件传输服务。TFTP 承载在 UDP 上,提供不可靠的数据流传输服务,不提供存取授权与认证机制,使用超时重传方式来保证数据的到达,与FTP 相比,TFTP的大小要小的多。现在最普遍使用的是第二版 TFTP(TFTP Version 2 , RFC 1350)

TFTP协议结构

因为TFTP使用UDP,而UDP使用IP,IP还可以使用其它本地通信方法。因此一个TFTP包中会有以下几段:本地媒介头,IP头,数据报头,TFTP头,剩下的就是TFTP数据了。TFTP在IP头中不指定任何数据,但是它使用UDP中的源和目标端口以及包长度域。由TFTP使用的包标记(TID)在这里被用做端口,因此TID必须介于0到65,535之间。
数据包的类型有以下五种:请求读RRQ,请求写WRQ,文件数据DATA,确认ACK,发生错误ERR。报文格式如图

TFTP工作过程

TFTP的工作过程很像停止等待协议,发送完一个文件块后就等待对方的确认,确认时应指明所确认的块号。发送完数据后在规定时间内收不到确认,就要重发数据PDU,发送确认PDU的一方若在规定时间内收不到下一个文件块,也要重发确认PDU。这样保证文件的传送不致因某一个数据报的丢失而告失败。

客户端执行步骤

发送RRQ或WRQ请求到服务器的69端口,等待一个数据包或者是ACK包。这个包将包含一个69以外的新端口号。接收到数据包,就以ACK包响应,接收到ACK包,就发送下一个数据包。准备处理超时错误或者是ERR包。

服务端执行步骤

发送RRQ或WRQ请求到服务器的69端口,等待一个数据包或者是ACK包。这个包将包含一个69以外的新端口号。接收到数据包,就以ACK包响应,接收到ACK包,就发送下一个数据包。准备处理超时错误或者是ERR包。

代码流程图

客户端源码

//UDP客户端创建TFTP请求
class UDPClient {

    // OPCODES 操作码
    private static final byte RRQ_CODE = 1;
    private static final byte WRQ_CODE = 2;
    private static final byte DATA_CODE = 3;
    private static final byte ACK_CODE = 4;
    private static final byte ERR_CODE = 5;

    // Connection variables
    // 数据报大小
    private static int PACKET_SIZE = 516;
    // 端口
    private static int PORT = 669;
    // 发送端口
    private static int SEND_PORT = 69;
    // 服务器端口
    private int server_TID;

    // 客户端套接字
    private static DatagramSocket clientSocket;
    // IP地址
    private static InetAddress IPAddress;

    // 请求套接字
    private static DatagramPacket request;
    // 文件输出流
    private static ByteArrayOutputStream output;

    // Error Codes
    private static String ERROR_1 = "File not found.";
    private static String ERROR_2 = "Access violation.";
    private static String ERROR_3 = "Disk Full.";
    private static String ERROR_4 = "Illegal TFTP operation.";
    private static String ERROR_5 = "Unknown Transfer ID.";
    private static String ERROR_6 = "File Already Exists.";
    private static String ERROR_7 = "No such user.";

    // Client storage (R/W) location
    private static File outputFileDir;

    /**
     * 
     * 主函数入口
     * 
     * @param args
     * @throws Exception
     * 
     * 
     */

    public static void main(String args[]) throws Exception {
        // 本地创建文件
        File currentDirectory = new File(new File("").getAbsolutePath());
        outputFileDir = new File(currentDirectory.getAbsolutePath() + "\\Client_Folder");
        outputFileDir.mkdirs();
        // 创建客户端对象并调用run函数
        UDPClient testClient = new UDPClient();
        testClient.run();
    }

    /**
     * 客户端运行函数
     */
    private void run() {

        // 获取字符缓冲区输入
        BufferedReader inFromUser = new BufferedReader(new InputStreamReader(System.in));

        // 得到 IP
        try {
            // IPAddress = InetAddress.getByName("localhost");
            IPAddress = InetAddress.getByName("172.20.10.3");
        } catch (UnknownHostException e1) {
        }

        // 轮询用户请求
        while (true) {

            // Assign a random TID for each request
            try {
                // 随机生成临时通信所用端口
                Random rand = new Random();
                PORT = rand.nextInt(6536);
                // 不为空且已绑定则关闭重启一个新的clientsocket
                if (clientSocket != null && clientSocket.isBound()) {
                    clientSocket.close();
                }
                clientSocket = new DatagramSocket(PORT);

            } catch (SocketException e1) {
                return;
            }

            System.out.println("Enter command: r/w filename mode");
            // Read user input
            String sentence;
            try {
                // 从输入缓冲区中读取一行
                sentence = inFromUser.readLine();
                System.out.println(sentence);

                // 判断退出
                if (sentence == "quit") {
                    break;
                }

                // 获取用户输入的命令的三部分
                String[] inputs = sentence.split(" ");
                // 判断输入是否非法
                if (inputs.length == 3 && (inputs[0].equals("r") || inputs[0].equals("w"))
                        && (inputs[2].toLowerCase().equals("octet") || inputs[2].toLowerCase().equals("netascii"))) {
                    String fileName = inputs[1];
                    String mode = inputs[2];
                    mode = mode.toLowerCase();

                    // 读
                    if (inputs[0].equals("r")) {
                        // 创建和发送 RRQ 请求
                        DatagramPacket request1 = buildRequestPacket(RRQ_CODE, fileName, mode);
                        clientSocket.send(request1);
                        // 从服务器获取文件
                        get(inputs[1], inputs[2]);
                    }
                    // 写
                    else if (inputs[0].equals("w")) {
                        // 创建和发送 WRQ 请求数据报包
                        DatagramPacket request1 = buildRequestPacket(WRQ_CODE, fileName, mode);
                        clientSocket.send(request1);

                        // 创建ACK数据报
                        byte[] tempArray = new byte[516];
                        DatagramPacket ackPacket = new DatagramPacket(tempArray, tempArray.length);

                        // 轮询服务器ACK响应
                        while (true) {
                            try {
                                clientSocket.setSoTimeout(600000);
                                clientSocket.receive(ackPacket);
                                server_TID = ackPacket.getPort();
                                // 接收到ACK以后发送文件
                                sendFile(inputs[1], inputs[2]);
                                break;
                            } catch (SocketTimeoutException e) {
                                // 超时报错
                                System.out.println("Error, no response from server.");
                            }
                        }
                    }
                }
                // 输入错误
                else {
                    System.out.println("Invalid input");
                }
            } catch (Exception e) {
                System.out.println("Error, invalid input");
            }
        }

        // 完毕关闭客户端端口
        clientSocket.close();

    }

    /**
     * 从服务器读取文件
     * 
     * @param fileName
     * @param mode
     * @throws Exception
     */
    private void get(String fileName, String mode) throws Exception {

        // 收到数据报
        output = receivePackets(fileName, mode);
        // 如果收到数据报,写文件到本地
        if (output != null) {
            writeFile(output, (fileName));
        }
    }

    // 创建请求数据报
    DatagramPacket buildRequestPacket(byte requestType, String fileName, String mode) {

        /**
         *     01/02 
         * -------------------------------------------------------------- | Opcode
         * |  Filename |     0   |     Mode    |    0   |
         * --------------------------------------------------------------- 2 bytes
         *     string    1 byte       string      1 byte
         * 
         */

        byte zero = 0;
        byte[] fileNameBytes = fileName.getBytes();
        byte[] modeBytes = mode.getBytes();
        int requestLength = 2 + (fileNameBytes.length) + 1 + (modeBytes.length) + 1;

        // 穿件request数据报缓冲区并填入各字段值
        ByteBuffer requestBuffer = ByteBuffer.allocate(requestLength);
        requestBuffer.put(zero);
        requestBuffer.put(requestType);
        requestBuffer.put(fileNameBytes);
        requestBuffer.put(zero);
        requestBuffer.put(modeBytes);
        requestBuffer.put(zero);

        // 转为字节数组并构建数据报
        byte[] requestByteArray = requestBuffer.array();
        DatagramPacket requestPacket = new DatagramPacket(requestByteArray, requestByteArray.length, IPAddress,
                SEND_PORT);

        // 返回数据报
        return requestPacket;
    }

    // 发送ACK
    private void sendAcknowledgement(byte[] blockNumber, int portNumber) {

        // 创建和发送ACK数据报
        byte[] ackBytes = { (byte) 0, ACK_CODE, blockNumber[0], blockNumber[1] };
        DatagramPacket ackPacket = new DatagramPacket(ackBytes, ackBytes.length, IPAddress, server_TID);
        try {
            // 发送数据报
            clientSocket.send(ackPacket);
        } catch (IOException e) {
        }
    }

    // 从服务器收到数据报
    private ByteArrayOutputStream receivePackets(String fileName, String mode) throws IOException {

        ByteArrayOutputStream byteOutputStream = new ByteArrayOutputStream();
        int block = 1;
        boolean packetsRemaining = true;
        byte[] packetByteBuffer;
        DatagramPacket receivedPacket;
        boolean gotTID = false;

        // 有剩余则一直循环
        while (packetsRemaining) {

            // 判断磁盘是否满
            if (outputFileDir.getTotalSpace() < 516) {
                sendError(3, ERROR_3, IPAddress, server_TID);
                return null;
            }

            // 获取数据报
            packetByteBuffer = new byte[PACKET_SIZE];
            receivedPacket = new DatagramPacket(packetByteBuffer, packetByteBuffer.length);

            // Continue receiving until a valid packet is received
            while (true) {
                try {
                    // 5s 延时 接收并回填
                    clientSocket.setSoTimeout(5000);
                    clientSocket.receive(receivedPacket);

                    // Ignore invalid ports
                    if ((receivedPacket.getPort() != server_TID) && gotTID) {
                        sendError(5, ERROR_5, receivedPacket.getAddress(), receivedPacket.getPort());
                    } else {
                        break;
                    }
                } catch (SocketTimeoutException e) {
                    // 服务器无响应
                    if (block == 1) {
                        System.out.println("Error, no response from server.");
                        return null;
                    }
                    // 获取上一个block高8位和低8位
                    byte blockMSB = (byte) (((block - 1) >>> 8) & 0xFF);
                    byte blockLSB = (byte) (((block - 1)) & 0xFF);
                    byte[] blockBytes = { blockMSB, blockLSB };
                    // 重新发送ACK
                    sendAcknowledgement(blockBytes, server_TID);
                }
            }

            // 取得服务端端口
            if (!gotTID) {
                server_TID = receivedPacket.getPort();
                gotTID = true;
            }

            // netascii重新编码
            if (mode.equals("netascii")) {
                String temp = new String(packetByteBuffer);
                packetByteBuffer = Charset.forName("US-ASCII").encode(temp).array();
                temp = new String(packetByteBuffer);
            }

            // 取操作码判断类型
            byte opCode = packetByteBuffer[1];
            if (opCode == DATA_CODE) {

                // 检查长度,若小于516则为最后一个数据报,置循环条件为false
                if (receivedPacket.getLength() < 516) {
                    packetsRemaining = false;

                    // 判断文件是否已经存在
                    File fileOut1 = new File(outputFileDir.getAbsoluteFile() + "\\" + fileName);
                    if (fileOut1.exists()) {
                        sendError(6, ERROR_6, IPAddress, server_TID);
                        return null;
                    }
                }

                // 取得块号
                byte[] temp = { (packetByteBuffer[2]), packetByteBuffer[3] };
                int number = (((packetByteBuffer[2] & 0xFF) << 8) + (packetByteBuffer[3] & 0xFF)) & 0xFFFF;

                // 将数据部分写入输出流
                DataOutputStream dataOut = new DataOutputStream(byteOutputStream);
                dataOut.write(packetByteBuffer, 4, receivedPacket.getLength() - 4);

                // 发送ACK 块号+1
                sendAcknowledgement(temp, server_TID);
                block++;
            }
            // 错误信息数据报
            else if (opCode == ERR_CODE) {
                // 分片分出错误信息并打印输出
                packetByteBuffer[1] = (byte) 0;
                packetByteBuffer[3] = (byte) 0;
                byte[] zeroByteArray = { 0 };
                String zeroString = new String(zeroByteArray);
                String errorMessage = new String(packetByteBuffer);
                String[] splitMessage = errorMessage.split(zeroString);
                System.out.println("Error : " + splitMessage[4]);
                return null;
            } // 其他错误
            else {
                sendError(0, "undefined error", IPAddress, server_TID);
                return null;
            }
        }
        //返回输出流
        return byteOutputStream;
    }

    // 客户端写文件
    private void writeFile(ByteArrayOutputStream byteOutputStream, String fileName) {
        try {
            //写文件并关闭输出流
            OutputStream outputStream = new FileOutputStream(outputFileDir + "\\" + fileName);
            byteOutputStream.writeTo(outputStream);
            byteOutputStream.close();
            outputStream.close();

            // 完成校验和计算
            MessageDigest md = MessageDigest.getInstance("MD5");
            InputStream is = Files.newInputStream(Paths.get(outputFileDir + "\\" + fileName));
            DigestInputStream dis = new DigestInputStream(is, md);
            // 读到文件末尾
            while ((is.read()) != -1) {
            }
            String digest = new String(md.digest());
            System.out.println("checksum : " + digest);
            is.close();

        } catch (FileAlreadyExistsException e) {
            System.out.println("File already exists.");

        } catch (IOException e) {
            System.out.println("Error writing file.");
        } catch (NoSuchAlgorithmException e) {
            System.out.println("Error writing file.");
        }
    }

    // 向服务端发送文件
    void sendFile(String fileName, String mode) throws IOException {
        // 已经收到ACK的情况下 要向服务端发送第一块数据
        int block = 1;
        int bytesRead = 0;
        boolean serverTIDCollected = false;

        try {

            // 写文件
            File file = new File(outputFileDir + "\\" + fileName);
            long length = file.length();
            InputStream in = new FileInputStream(file);
            byte[] bytes = new byte[512];
            byte[] bytesToRead = new byte[516];

            // 只要仍然可读,缓冲区每次读512B
            while ((bytesRead = in.read(bytes, 0, 512)) != -1) {

                int size = bytesRead + 4;

                // 缓冲区写入 03+块号+数据 转为字节数组 组装成发送数据报
                ByteBuffer sendBuffer = ByteBuffer.allocate(size);
                sendBuffer.put((byte) 0);
                sendBuffer.put((byte) 3);
                sendBuffer.put((byte) (block >>> 8));
                sendBuffer.put((byte) (block & 0xFF));
                sendBuffer.put(bytes, 0, bytesRead);

                byte[] bytesToSend = sendBuffer.array();

                // netascii编码转换
                if (mode.equals("netascii")) {
                    String temp = new String(bytesToSend);
                    bytesToSend = Charset.forName("US-ASCII").encode(temp).array();
                    temp = new String(bytesToSend);
                }
                // 组装发送数据包 -----------bytesToSend
                DatagramPacket sendPacket = new DatagramPacket(bytesToSend, bytesToSend.length, IPAddress, server_TID);

                // 发送数据报
                clientSocket.send(sendPacket);
                // 组装接收数据包ACK -----------bytesToRead
                DatagramPacket receivePacket = new DatagramPacket(bytesToRead, bytesToRead.length);

                // 轮询等待响应
                while (true) {
                    try {
                        // 设置延时50s 接收ACK并回填bytesToRead
                        clientSocket.setSoTimeout(50000);
                        clientSocket.receive(receivePacket);
                        // 如果和服务端端口不合报错
                        if (receivePacket.getPort() != server_TID && serverTIDCollected) {
                            sendError(5, ERROR_5, receivePacket.getAddress(), receivePacket.getPort());
                        } else {
                            // 端口相同传输正确,跳出轮询
                            break;
                        }

                    } // 超时重传发送数据报
                    catch (SocketTimeoutException e) {
                        clientSocket.send(sendPacket);
                    }
                }

                // 获取服务端端口
                if (!serverTIDCollected) {
                    server_TID = receivePacket.getPort();
                    serverTIDCollected = true;
                }

                // netascii编码
                if (mode.equals("netascii")) {
                    String temp = new String(bytesToRead);
                    bytesToRead = Charset.forName("US-ASCII").encode(temp).array();
                    temp = new String(bytesToRead);
                }

                // 从服务器响应回填的bytesToRead中读出响应数据报类型
                switch (bytesToRead[1]) {
                // 读取ACK块号 块号不匹配报错 匹配则自增到下一块 退出switch
                case ACK_CODE:
                    int temp = ((bytesToRead[2] & 0xFF) << 8) + (bytesToRead[3] & 0xFF);
                    if (temp != block) {
                        sendError(0, "Incorrent Block", IPAddress, server_TID);
                    } else {
                        block++;
                    }
                    break;
                // 片出ERR错误消息并打印,跳出sendFile函数
                case ERR_CODE:
                    bytesToRead[1] = (byte) 0;
                    bytesToRead[3] = (byte) 0;
                    byte[] zeroByteArray = { 0 };
                    String zeroString = new String(zeroByteArray);
                    String errorMessage = new String(bytesToRead);
                    String[] splitMessage = errorMessage.split(zeroString);
                    System.out.println("Error : " + splitMessage[4]);
                    in.close();
                    return;

                // 不支持的类型报错
                default:
                    sendError(4, ERROR_4, IPAddress, server_TID);
                    in.close();
                    return;
                }
            }
            in.close();

        } catch (FileNotFoundException e) {
            sendError(1, ERROR_1, IPAddress, server_TID);
        } catch (IOException e) {
            System.out.println("Error reading file");
            return;
        }
    }

    // 发送错误数据报
    void sendError(int errorCode, String errorMessage, InetAddress address, int port) {

        // 打印错误信息 
        System.out.println("Error " + errorMessage);
        byte code = (byte) (errorCode & 0xFF);
        byte[] message = errorMessage.getBytes();

        /**   05        0+code     msg     0
         * -----------------------------------------------
         * | Opcode | ErrorCode | ErrMsg | 0 |
         * -----------------------------------------------
         * 2 bytes   2 bytes     string  1 byte
         */
        int size = 5 + message.length;
        ByteBuffer error = ByteBuffer.allocate(size);
        error.put((byte) 0);
        error.put((byte) 5);
        error.put((byte) 0);
        error.put(code);
        error.put(message);
        error.put((byte) 0);

        //组装成错误数据报并发送
        DatagramPacket errorPacket = new DatagramPacket(error.array(), error.array().length, address, port);
        try {
            clientSocket.send(errorPacket);
        } catch (IOException e) {
            System.out.println("Error sending error");
        }
        System.out.println("Client Error :" + errorMessage);
    }

}

服务端源码

class UDPServer {

    // OPCODES 报文识别码 byte类型
    // 请求读
    private static final byte RRQ = 1;
    // 请求写
    private static final byte WRQ = 2;
    // 传送数据
    private static final byte DATA = 3;
    // 确认报文
    private static final byte ACK = 4;
    // 错误报文
    private static final byte ERR = 5;

    // 数据报长度512 B + 头长度4 B = 516 B
    private static int PACKET_SIZE = 516;

    // Sockets
    // UDP协议的Socket 数据报
    // 服务端数据报
    static DatagramSocket serverSocket;
    // 请求数据报
    static DatagramSocket requestSocket;
    // 数字形式保存的IP地址
    private static InetAddress IPAddress;

    // Ports
    // 服务器端口69
    private static int SERVER_PORT = 69;
    // 客户端端口669
    private static int CLIENT_PORT = 669;

    // Error IDs 错误信息
    // 文件未找到
    private static String ERROR_1 = "File not found.";
    // 访问非法
    private static String ERROR_2 = "Access violation.";
    // 磁盘已满
    private static String ERROR_3 = "Disk Full.";
    // 非法TFTP操作
    private static String ERROR_4 = "Illegal TFTP operation.";
    // 未知传输ID
    private static String ERROR_5 = "Unknown Transfer ID.";
    // 文件已存在
    private static String ERROR_6 = "File Already Exists.";
    // 无此用户
    private static String ERROR_7 = "No such user.";

    // 字节数组输出流
    private static ByteArrayOutputStream output;

    // TIDs
    // 目的端口
    private int destination_TID;
    // 当前端口
    private int current_TID;

    // Server file storage location 输出文件目录
    static File outputFileDir;

    public static void main(String args[]) throws Exception {
        // UDPserver实例化
        UDPServer testServer = new UDPServer();
        // Create and locate the storage directory 根据当前绝对路径创建新文件
        File currentDirectory = new File(new File("").getAbsolutePath());
        // 创建输出文件Server_Folder
        outputFileDir = new File(currentDirectory.getAbsolutePath() + "\\Server_Folder\\");
        // 创建输出文件目录
        outputFileDir.mkdirs();
        // Run the server 跳到run函数
        testServer.run();

    }

    /**
     * 运行server端
     * 
     * @throws Exception
     */
    void run() throws Exception {

        System.out.println("Server Started");
        // 开启69端口处理客户端请求
        try {
            // 创建一个DatagramSocket实例,并将该对象绑定到本机默认IP地址、69端口。
            requestSocket = new DatagramSocket(69);
        } catch (SocketException e) {
            e.printStackTrace();
        }
        // Data arrays and packets
        // 请求数据字节数组
        byte[] requestData = new byte[516];
        // 构造 DatagramPacket,用来请求长度为 516 的数据包
        DatagramPacket request = new DatagramPacket(requestData, requestData.length);
        // 接收数据字节数组
        byte[] receiveData = new byte[100];
        // 发送数据字节数组
        byte[] sendData = new byte[1024];

        // 长轮询 等待请求
        while (true) {
            try {
                System.out.println("Server waiting for requests");
                // 等待请求并加载入receivePacket
                // 构造 DatagramPacket,用来接收长度为 100的数据包
                DatagramPacket receivePacket = new DatagramPacket(receiveData, receiveData.length);
                // 从此套接字接收数据报 UDP数据报就是套接字 若未接收成功则不往下运行下一行代码
                requestSocket.receive(receivePacket);

                System.out.println("Request Received");

                /**
                 * 
                 * WRQ/RRQ数据报格式
                 * 
                 * 
                 * 01 / 02 
                 * -------------------------------------------------------------- |
                 * Opcode | Filename | 0 | Mode | 0 |
                 * --------------------------------------------------------------- 2 bytes
                 * string 1 byte string 1 byte
                 * 
                 */

                // 决定接受数据报类型
                //接收数据字节数组类型,操作码为第2个字节
                byte opCodeReceived = receiveData[1];
                // 0比特数组 长度为1 内容为0
                byte[] zeroByte = { 0 };
                // 转为字符串型
                String zeroString = new String(zeroByte);
                // 接收数据字节数组第2个字节(9-16位)设置为0
                receiveData[1] = zeroByte[0];
                // 接收数据字符串
                String inputString = new String(receiveData);
                // 根据0比特字符串分段放入pieces数组
                String[] pieces = inputString.split(zeroString);
                // pieces数组第3段设置为文件
                String fileString = pieces[2];
                // pieces数组第4段设置为模式
                String mode = pieces[3];
                // 将大写字符转换为小写
                mode = mode.toLowerCase();

                // 获取客户端IP和端口
                // 返回某台机器的 IP 地址,此数据报将要发往该机器或者是从该机器接收到的
                IPAddress = receivePacket.getAddress();
                // 获取客户端端口号
                CLIENT_PORT = receivePacket.getPort();

                // 给服务端指定一个随机数 TID
                Random rand = new Random();
                // 随机生成0-6535的随机数
                current_TID = rand.nextInt(6535);
                // 和端口号相同自增
                if (current_TID == CLIENT_PORT) {
                    current_TID++;
                }
                // Opens the TID socket 创建数据报套接字并将其绑定到本地主机上的指定端口
                serverSocket = new DatagramSocket(current_TID);

                // Handles the received input type 输入类型分发
                switch (opCodeReceived) {
                // Read request 读请求
                case RRQ:
                    System.out.println("Sending File for read request");
                    // Sends a requested file to the user from storage
                    sendFile(fileString, mode);
                    break;
                // Write request 写请求
                case WRQ:
                    System.out.println("Receiving File from write request");
                    // Obtains and writes a user file to the server storage
                    get(fileString, mode);
                    break;
                // Handles error messages 处理错误信息
                case ERR:
                    // Formats input to obtain error message
                    receiveData[1] = (byte) 0;
                    receiveData[3] = (byte) 0;
                    byte[] zeroByteArray = { 0 };
                    String errorMessage = new String(receiveData);
                    String[] splitMessage = errorMessage.split(zeroString);
                    System.out.println("Error : " + splitMessage[4]);
                    break;
                // If unacknowledged data type is received, respond with an error message 未知类型
                default:
                    sendError(4, ERROR_4, IPAddress, CLIENT_PORT);
                    break;
                }
                // Close the socket when finished
                serverSocket.close();

                // QUit the server if the quit command is issued
                if (mode == "quit") {
                    break;
                }

            } catch (NullPointerException e) {
                sendError(0, "Undefined", IPAddress, CLIENT_PORT);
            }
        }

        // Close the server socket when finished
        requestSocket.close();

    }

    // Sends a file to a client
    /**
     * 发送文件
     * 
     * @param fileName
     * @param mode
     */
    void sendFile(String fileName, String mode) {

        // Number of blocks received and bytes read
        int block = 1;
        int bytesRead = 0;

        try {

            // Opens reference to desired file
            File file;
            try {
                file = new File(outputFileDir.getAbsolutePath() + "\\" + fileName);
            } catch (Exception e) {
                // Notify the user if it is not found, the search for clients again
                sendError(1, ERROR_1, IPAddress, CLIENT_PORT);
                return;
            }

            // If the file exists send it to the client
            if (file.exists()) {

                /**
                 * 
                 * ----------------------------------- | Opcode | Block # | Data |
                 * ----------------------------------- 2 bytes 2 bytes n bytes
                 * 
                 * 
                 */

                // File IO variables
                InputStream in = new FileInputStream(file);
                byte[] bytes = new byte[512];
                byte[] bytesToRead = new byte[516];

                // Checksum computation
                MessageDigest md = MessageDigest.getInstance("MD5");
                // 创建MD5输入流
                /**
                 * 
                 * 
                 * --------------------------------------------------------------------------------------------------------------------
                 * 
                 * 
                 */
                DigestInputStream dis = new DigestInputStream(in, md);

                // Read till the end of the file, storing read data into bytes, in 512 byte
                // chunks 一直读到尾部 读取512个字节存到从bytes[0]开始的字节缓冲区并保存 返回成功读取的字节数
                while ((bytesRead = in.read(bytes, 0, 512)) != -1) {
                    // Determine packet size
                    int size = bytesRead + 4;

                    // Buffer data into bytesToSend 从堆空间中分配一个容量大小为size 516的byte数组作为缓冲区的byte数据存储器
                    ByteBuffer sendBuffer = ByteBuffer.allocate(size);
                    // 相对写,向0的位置写入一个byte,并将postion+1,为下次读写作准备
                    sendBuffer.put((byte) 0);
                    sendBuffer.put((byte) 3);
                    // 无符号右移,忽略符号位,空位都以0补齐 写入block块数
                    sendBuffer.put((byte) (block >>> 8));
                    sendBuffer.put((byte) (block & 0xFF));
                    // 从src数组中的bytes到offset+length区域读取数据并使用相对写写入此byteBuffer
                    sendBuffer.put(bytes, 0, bytesRead);
                    byte[] bytesToSend = sendBuffer.array();

                    // If in netascii mode, encode data appropriately 仅支持txt文件
                    if (mode.equals("netascii")) {
                        String temp = new String(bytesToSend);
                        bytesToSend = Charset.forName("US-ASCII").encode(temp).array();
                        temp = new String(bytesToSend);
                    }
                    // Make Packet
                    DatagramPacket sendPacket = new DatagramPacket(bytesToSend, bytesToSend.length, IPAddress,
                            CLIENT_PORT);

                    // Send Packet
                    serverSocket.send(sendPacket);

                    // Wait for Acknowledgement of last block, otherwise resend last packet

                    DatagramPacket receivePacket = new DatagramPacket(bytesToRead, bytesToRead.length, IPAddress,
                            SERVER_PORT);
                    // Wait until a valid packet is received
                    while (true) {
                        try {
                            // Set timeout to 50s
                            serverSocket.setSoTimeout(50000);
                            serverSocket.receive(receivePacket);
                            // Ignore data from incorrect ports, and send them an error message
                            if (receivePacket.getPort() != CLIENT_PORT) {
                                sendError(5, ERROR_5, receivePacket.getAddress(), receivePacket.getPort());
                            } else {
                                break;
                            }
                        } catch (SocketTimeoutException e) {
                            serverSocket.send(sendPacket);
                        }
                    }

                    // Encode data if in netascii mode
                    if (mode.equals("netascii")) {
                        String temp = new String(bytesToRead);
                        bytesToRead = Charset.forName("US-ASCII").encode(temp).array();
                        temp = new String(bytesToRead);
                    }

                    // Determine received packet type and respond
                    switch (bytesToRead[1]) {
                    case ACK:
                        // Determine which block was acknowledged
                        int temp = ((bytesToRead[2] & 0xFF) << 8) + (bytesToRead[3] & 0xFF);
                        if (temp != block) {
                            System.out.println("INCORRECT BLOCK");
                        }
                        // Increment the block count
                        else {
                            block++;
                        }
                        break;
                    // Handle error messages
                    case ERR:
                        // Format data to obtain error message
                        bytesToRead[1] = (byte) 0;
                        bytesToRead[3] = (byte) 0;
                        byte[] zeroByteArray = { 0 };
                        String zeroString = new String(zeroByteArray);
                        String errorMessage = new String(bytesToRead);
                        String[] splitMessage = errorMessage.split(zeroString);
                        System.out.println("Error : " + splitMessage[4]);
                        return;
                    // Handle non accepted data type with an error message
                    default:
                        sendError(4, ERROR_4, IPAddress, CLIENT_PORT);
                        return;
                    }
                }

                // Display checksum and close streams
                String digest = new String(md.digest());
                System.out.println("Data chucksum: " + digest);
                in.close();

            } else {
                // Handle file not found errors
                sendError(1, ERROR_1, IPAddress, CLIENT_PORT);
            }
        }
        // Handle file not found errors
        catch (FileNotFoundException e) {
            sendError(1, ERROR_1, IPAddress, CLIENT_PORT);
        }
        // Handle file not found errors
        catch (IOException e) {
            sendError(1, ERROR_1, IPAddress, CLIENT_PORT);
            return;
        }
        // Handle checksum errors
        catch (NoSuchAlgorithmException e) {
            sendError(0, "Checksum error.", IPAddress, CLIENT_PORT);
            return;
        }
    }

    // Send an error message to the client
    void sendError(int errorCode, String errorMessage, InetAddress address, int port) {

        // Encode the error message
        byte code = (byte) (errorCode & 0xFF);
        byte[] message = errorMessage.getBytes();

        /**
         * 05 ----------------------------------------------- | Opcode | ErrorCode |
         * ErrMsg | 0 | ----------------------------------------------- 2 bytes 2 bytes
         * string 1 byte
         * 
         */
        // Buffer the error message as bytes
        int size = 5 + message.length;
        ByteBuffer error = ByteBuffer.allocate(size);
        error.put((byte) 0);
        error.put((byte) 5);
        error.put((byte) 0);
        error.put(code);
        error.put(message);
        error.put((byte) 0);

        // Create and send the error packet
        DatagramPacket errorPacket = new DatagramPacket(error.array(), error.array().length, address, port);
        try {
            // 服务端发送错误数据报
            serverSocket.send(errorPacket);
        } catch (IOException e) {
            System.out.println("Error sending error");
            // 错误返回
            return;
        }
        System.out.println("Server Error sent" + errorMessage);
    }

    // Write a file sent from a client to the servers storage
    private void get(String fileName, String mode) throws Exception {

        // Acknowledge the write request 发送0块的ACK
        byte[] block = { 0, 0 };
        sendAcknowledgement(block, CLIENT_PORT);
        // Receive the file
        output = receivePackets(fileName, mode);
        // If the file was received successfully, attempt to write it to a file
        if (output != null) {
            writeFile(output, (fileName));
        }
    }

    // Send an acknowledgement 发给客户端
    private void sendAcknowledgement(byte[] blockNumber, int portNumber) {

        // Create the ack packet

        /**
         * 04 00 ----------------------- | Opcode | Block | ----------------------- 2
         * bytes 2 bytes
         *
         */

        byte[] ackBytes = { (byte) 0, ACK, blockNumber[0], blockNumber[1] };
        DatagramPacket ackPacket = new DatagramPacket(ackBytes, ackBytes.length, IPAddress, CLIENT_PORT);

        try {
            serverSocket.send(ackPacket);
        } catch (IOException e) {
            sendError(0, "Error sending ack", IPAddress, CLIENT_PORT);
        }
    }

    // Receive the file packets sent from a client
    private ByteArrayOutputStream receivePackets(String fileName, String mode) throws IOException {

        // File IO variables
        ByteArrayOutputStream byteOutputStream = new ByteArrayOutputStream();
        // 0已经收到了 发第1块
        int block = 1;
        boolean packetsRemaining      = true;
        byte[] packetByteBuffer;
        DatagramPacket receivedPacket;

        // If there are packets to come, continue receiving
        while (packetsRemaining) {

            // Determine if there is sufficient diskspace, otherwise raise an error
            if (outputFileDir.getTotalSpace() < 516) {
                sendError(3, ERROR_3, IPAddress, CLIENT_PORT);
                return null;
            }
            // Receive the new data packet 516B
            packetByteBuffer = new byte[PACKET_SIZE];
            receivedPacket = new DatagramPacket(packetByteBuffer, packetByteBuffer.length);

            // Wait for a valid packet
            while (true) {
                try {
                    // 5s timeout
                    serverSocket.setSoTimeout(5000);
                    serverSocket.receive(receivedPacket);
                    if (receivedPacket.getPort() != CLIENT_PORT) {
                        sendError(5, ERROR_5, receivedPacket.getAddress(), receivedPacket.getPort());
                    } else {
                        break;
                    }
                    // Resend the ack if the socket times out
                } catch (SocketTimeoutException e) {
                    // 重发上一块的ack 高几位 低几位
                    byte blockMSB = (byte) (((block - 1) >>> 8) & 0xFF);
                    byte blockLSB = (byte) (((block - 1)) & 0xFF);
                    byte[] blockA = { blockMSB, blockLSB };
                    sendAcknowledgement(blockA, CLIENT_PORT);
                }
            }

            // Buffer the received data 新得到的包的字节数组
            ByteBuffer receiveBuffer = ByteBuffer.allocate(receivedPacket.getLength());
            receiveBuffer.put(packetByteBuffer, 0, receivedPacket.getLength());
            byte[] receivedData = receiveBuffer.array();

            // Format if required for netascii
            if (mode.equals("netascii")) {
                String temp = new String(receivedData);
                receivedData = Charset.forName("US-ASCII").encode(temp).array();
                temp = new String(receivedData);
            }

            // Determine data type
            byte opCode = receivedData[1];

            if (opCode == DATA) {

                // Check if final packet
                if (receivedPacket.getLength() < 516) {
                    packetsRemaining = false;
                    // Check if the file exists
                    File fileOut = new File(outputFileDir.getAbsoluteFile() + "\\" + fileName);
                    if (fileOut.exists()) {
                        sendError(6, ERROR_6, IPAddress, CLIENT_PORT);
                        return null;
                    }
                }

                // Get block Number
                byte[] temp1 = { (receivedData[2]), receivedData[3] };
                int number = (((receivedData[2] & 0xFF) << 8) + (receivedData[3] & 0xFF)) & 0xFFFF;

                // BUffer file
                DataOutputStream dataOut = new DataOutputStream(byteOutputStream);
                // 数据报的4-516为数据部分,写入输出流,头部不写
                dataOut.write(receivedData, 4, receivedData.length - 4);

                // Send ACK and increment block count 发送当前接受宝的确认报
                sendAcknowledgement(temp1, CLIENT_PORT);
                block++;
            }
            // Handle error responses
            else if (opCode == ERR) {
                packetByteBuffer[1] = (byte) 0;
                packetByteBuffer[3] = (byte) 0;
                byte[] zeroByteArray = { 0 };
                String zeroString = new String(zeroByteArray);
                String errorMessage = new String(packetByteBuffer);
                String[] splitMessage = errorMessage.split(zeroString);
                System.out.println("Error : " + splitMessage[4]);
                return null;
            } else {
                sendError(0, "Undefined error", IPAddress, CLIENT_PORT);
                return null;
            }
        }

        // Return byte output
        return byteOutputStream;
    }

    // Writes a received file to the server storage
    private void writeFile(ByteArrayOutputStream byteOutputStream, String fileName) {
        try {

            // Write if the file doesn't exist, otherwise raise an error
            File fileOut = new File(outputFileDir.getAbsoluteFile() + "\\" + fileName);
            if (fileOut.exists()) {
                sendError(6, ERROR_6, IPAddress, CLIENT_PORT);
                return;
            }
            // Write the file
            OutputStream outputStream = new FileOutputStream(fileOut);
            byteOutputStream.writeTo(outputStream);

            // Close the output streams
            byteOutputStream.close();
            outputStream.close();
        }
        // Handle file errors
        catch (FileAlreadyExistsException e) {
            sendError(6, ERROR_6, IPAddress, CLIENT_PORT);
            return;
        } catch (IOException e) {
            sendError(6, ERROR_6, IPAddress, CLIENT_PORT);
            return;
        }
    }

}

程序改进

数据包为512整数倍时,在服务器接收到最后一块数据块后发送ACK后即完成任务,但若此时ACK丢包,CLIENT端会一直超时等待重发最后一个空包,只能走服务端的WRQ/RRQ判断空包为非法操作回执。另可采用多线程来模拟多个客户端同时访问服务器的情况,以检测传输的可靠性。

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注

©2018-2024 Howell版权所有 备案号:冀ICP备19000576号