feat: added error handling

main
patchy oss 2 months ago
parent e6345c5ab3
commit f75051a040

@ -3,11 +3,15 @@ package main
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"io/fs"
"log"
"net"
"os"
"strings"
"syscall"
)
const (
@ -57,63 +61,76 @@ func (tftp *tftpServer) listenAndServe() {
log.Printf("error while reading packet: '%v'\n", err)
continue
}
err = tftp.handleConnection(addr, numRead, body)
if err != nil {
log.Printf("got some error: '%v'\n", err)
}
tftp.handleConnection(addr, numRead, body)
}
}
func (tftp *tftpServer) handleConnection(addr net.Addr, numRead int, body []byte) error {
func (tftp *tftpServer) handleConnection(addr net.Addr, numRead int, body []byte) {
cli, ok := tftp.connections[addr.String()]
if !ok {
cli = newClient(addr)
}
req, err := newRequest(numRead, body)
if err != nil {
return err
}
err := func() error {
req, err := newRequest(numRead, body)
if err != nil {
return err
}
err = tftp.handleRequest(cli, req)
if err != nil {
err = tftp.handleRequest(cli, req)
if err != nil {
return err
}
resp := newResponse(cli, req)
err = tftp.handleResponse(cli, resp)
if err != nil {
return err
}
_, err = tftp.sendResponse(cli, resp)
return err
}
}()
resp := newResponse(cli, req)
err = tftp.handleResponse(cli, resp)
if err != nil {
return err
tftpErr, ok := err.(*tftpError)
if !ok {
log.Printf("Got unexpected error: %v\n", err)
tftpErr = newTFTPError(ecNDEF, "Unexpected error.")
}
_, err = tftp.sendError(cli, tftpErr)
if err != nil {
panic(err)
}
}
_, err = tftp.sendResponse(cli, resp)
return err
}
func (tftp *tftpServer) handleRequest(cli *client, req *request) error {
// check for filename (rrq)
// check for filename (wrq)
// check for disk space
// check for file exist (wrq)
// access violation ?
// no such user ?
// checking for illegal operations
if req.opcode < opRRQ || req.opcode > opERROR {
return fmt.Errorf("Illegal operation!\n")
return newTFTPError(ecILL)
}
// checking for unknown client
_, clientExists := tftp.connections[cli.tid.String()]
if !clientExists && req.opcode != opRRQ && req.opcode != opWRQ {
return fmt.Errorf("Unknown client!\n")
return newTFTPError(ecUTID)
}
// TODO: handle this properly (probably will need to close 'connection')
if req.opcode == opERROR {
log.Printf("Got error from client: '%s' (%v)\n", req.errorMessage, req.number)
return nil
}
if !clientExists {
log.Printf("Got new client: %v\n", cli.tid.String())
err := cli.prepareFromRequest(req)
if err != nil {
return err
@ -121,9 +138,13 @@ func (tftp *tftpServer) handleRequest(cli *client, req *request) error {
tftp.connections[cli.tid.String()] = cli
}
// TODO: last data packet, close the client!
if req.opcode == opDATA {
_, err := io.Copy(cli.file, bytes.NewReader(req.body))
if err != nil {
if errors.Is(err, syscall.ENOSPC) {
err = newTFTPError(ecDSK)
}
return err
}
}
@ -138,6 +159,7 @@ func (tftp *tftpServer) handleResponse(cli *client, resp *response) error {
return err
}
if cli.bytesLeft <= 0 {
log.Printf("Client '%v' has received a file.\n", cli.tid.String())
cli.file.Close()
delete(tftp.connections, cli.tid.String())
}
@ -148,6 +170,11 @@ func (tftp *tftpServer) handleResponse(cli *client, resp *response) error {
return nil
}
func (tftp *tftpServer) sendError(cli *client, err *tftpError) (int, error) {
log.Println(err)
return tftp.sendResponse(cli, &response{opERROR, uint16(err.code), toCString(err.message.Error())})
}
func (tftp *tftpServer) sendResponse(cli *client, resp *response) (int, error) {
header := []byte{0x0, byte(resp.opcode), 0x0, 0x0}
binary.BigEndian.PutUint16(header[2:], resp.number)
@ -171,12 +198,24 @@ func (cli *client) prepareFromRequest(req *request) error {
var err error
var f *os.File
// TODO: clean path to filename
if req.opcode == opRRQ {
f, err = os.Open(req.filename)
} else {
if _, err := os.Stat(req.filename); !errors.Is(err, fs.ErrNotExist) {
return newTFTPError(ecFEX)
}
f, err = os.Create(req.filename)
}
if err != nil {
switch {
case errors.Is(err, fs.ErrNotExist):
err = newTFTPError(ecFNF)
case errors.Is(err, fs.ErrPermission):
err = newTFTPError(ecACV)
case errors.Is(err, syscall.ENOSPC):
err = newTFTPError(ecDSK)
}
return err
}
@ -226,7 +265,7 @@ func newRequest(numRead int, body []byte) (*request, error) {
return nil, err
}
if req.mode != "octet" {
return nil, fmt.Errorf("Incorrect mode '%v'. This server only supports 'octet' mode.\n", req.mode)
return nil, newTFTPError(ecNDEF, "Incorrect mode '%v'. This server supports only 'octet' mode.", req.mode)
}
req.body = req.body[n:]
@ -270,16 +309,30 @@ func newResponse(cli *client, req *request) *response {
return resp
}
type operation byte
type tftpError struct {
code errorCode
message error
}
const (
opUNK operation = iota
opRRQ
opWRQ
opDATA
opACK
opERROR
)
func newTFTPError(code errorCode, clientMessage ...string) *tftpError {
if code > ecNOUS {
code = ecNDEF
}
message := tftpErrors[code]
if code == ecNDEF {
message = errors.New(strings.Join(clientMessage, " "))
}
return &tftpError{
code: code,
message: message,
}
}
func (err *tftpError) Error() string {
return fmt.Sprintf("TFTP Error (%v): %v", err.code, err.message)
}
type errorCode uint16
@ -294,26 +347,27 @@ const (
ecNOUS
)
var errorMessages = map[errorCode]string{
ecNDEF: "",
ecFNF: "File not found.",
ecACV: "Access violation.",
ecDSK: "Disk full or allocation exceeded.",
ecILL: "Illegal TFTP operation.",
ecUTID: "Unknown transfer ID.",
ecFEX: "File already exists.",
ecNOUS: "No such user.",
var tftpErrors = [...]error{
ecNDEF: errors.New(""),
ecFNF: errors.New("File not found."),
ecACV: errors.New("Access violation."),
ecDSK: errors.New("Disk full or allocation exceeded."),
ecILL: errors.New("Illegal TFTP operation."),
ecUTID: errors.New("Unknown transfer ID."),
ecFEX: errors.New("File already exists."),
ecNOUS: errors.New("No such user."),
}
/*
func sendError(addr net.Addr, err errorCode, optMessage ...string) (int, error) {
message := errorMessages[err]
if err == ecNDEF {
message = strings.Join(optMessage, "\n")
}
return sendResponse(addr, &response{opERROR, uint16(err), toCString(message)})
}
*/
type operation byte
const (
opUNK operation = iota
opRRQ
opWRQ
opDATA
opACK
opERROR
)
func toCString(src string) []byte {
return append([]byte(src), 0x0)

Loading…
Cancel
Save