1 -----------------------------------------------------------------------------
2 -- TFTP support for the Lua language
5 -----------------------------------------------------------------------------
7 -----------------------------------------------------------------------------
9 -----------------------------------------------------------------------------
11 local table = require("table")
12 local math
= require("math")
13 local string = require("string")
14 local socket
= require("socket")
15 local ltn12
= require("ltn12")
16 local url
= require("socket.url")
19 -----------------------------------------------------------------------------
21 -----------------------------------------------------------------------------
local char = string.char
local byte = string.byte

-- TFTP opcodes (RFC 1350, section 5); their order matches OP_INV below
local OP_RRQ = 1
local OP_WRQ = 2
local OP_DATA = 3
local OP_ACK = 4
local OP_ERROR = 5
-- opcode -> name, e.g. OP_INV[OP_ACK] == "ACK"
local OP_INV = {"RRQ", "WRQ", "DATA", "ACK", "ERROR"}
33 -----------------------------------------------------------------------------
34 -- Packet creation functions
35 -----------------------------------------------------------------------------
--- Build a TFTP read-request (RRQ) packet.
-- Layout: 2-byte opcode, file name, zero byte, transfer mode, zero byte.
-- @param source remote file name.
-- @param mode transfer mode string (e.g. "octet").
-- @return the packet as a binary string.
local function RRQ(source, mode)
    local nul = char(0)
    local opcode = char(0, OP_RRQ)
    return opcode .. source .. nul .. mode .. nul
end
--- Build a TFTP write-request (WRQ) packet.
-- Layout: 2-byte opcode, file name, zero byte, transfer mode, zero byte.
-- Uses the WRQ opcode; sending OP_RRQ here would make the server treat
-- the packet as a read request.
-- @param source remote file name.
-- @param mode transfer mode string (e.g. "octet").
-- @return the packet as a binary string.
local function WRQ(source, mode)
    return char(0, OP_WRQ) .. source .. char(0) .. mode .. char(0)
end
--- Build a TFTP acknowledgment (ACK) packet.
-- @param block number of the DATA block being acknowledged
--   (non-negative integer).
-- @return 4-byte packet: 2-byte ACK opcode, then the block number in
--   big-endian order.
local function ACK(block)
    -- The wire field is 16 bits wide. Wrap larger counters instead of
    -- letting string.char reject an out-of-range byte. math.floor keeps
    -- this valid on every Lua version the file may target (no `%`, no
    -- math.mod).
    local n = block - math.floor(block / 65536) * 65536
    local high = math.floor(n / 256)
    local low = n - high * 256
    return char(0, OP_ACK, high, low)
end
--- Read the opcode of a TFTP datagram.
-- The opcode is the first two bytes, big-endian.
-- @param dgram raw datagram string.
-- @return the opcode as a number. Missing bytes count as zero, so a
--   truncated datagram yields an opcode no caller handles, rather than
--   an arithmetic-on-nil error.
local function get_OP(dgram)
    local hi, lo = byte(dgram, 1), byte(dgram, 2)
    return (hi or 0)*256 + (lo or 0)
end
56 -----------------------------------------------------------------------------
57 -- Packet analysis functions
58 -----------------------------------------------------------------------------
--- Split a TFTP DATA packet into its block number and payload.
-- Layout: 2-byte opcode, 2-byte big-endian block number, then the data.
-- @param dgram raw DATA datagram. It must be at least 4 bytes long;
--   shorter input raises an arithmetic error.
-- @return block number, payload string (possibly empty).
local function split_DATA(dgram)
    local block = byte(dgram, 3)*256 + byte(dgram, 4)
    local data = string.sub(dgram, 5)
    return block, data
end
--- Describe the error carried by a TFTP ERROR packet.
-- Layout: 2-byte opcode, 2-byte big-endian error code, message, zero byte.
-- Missing header bytes count as zero. A missing terminator means the
-- message runs to the end of the datagram. So any input yields a string.
-- @param dgram raw datagram.
-- @return "error code <n>: <message>".
local function get_ERROR(dgram)
    local code = (byte(dgram, 3) or 0)*256 + (byte(dgram, 4) or 0)
    -- Use a plain find: Lua 5.1 patterns cannot contain a literal "\0".
    local stop = string.find(dgram, "\0", 5, true)
    local msg
    if stop then
        msg = string.sub(dgram, 5, stop - 1)
    else
        msg = string.sub(dgram, 5)
    end
    return string.format("error code %d: %s", code, msg)
end
72 -----------------------------------------------------------------------------
74 -----------------------------------------------------------------------------
--- Download a file over TFTP (RFC 1350) in octet mode.
-- Sends a read request, then receives DATA packets one at a time. Each
-- packet is acknowledged, and new payloads are passed to the sink.
-- Errors are raised through socket.try, or through the newtry handler,
-- which closes the UDP socket first. Callers are expected to run this
-- under socket.protect, as `get` does.
-- @param gett request table (not modified):
--   host  server name or IP address (required);
--   port  server port (defaults to 69, the TFTP well-known port);
--   path  remote path; one leading "/" is removed and the rest is
--         URL-unescaped;
--   sink  optional ltn12 sink for the file contents (default: discard).
-- @return 1 once the final (short) DATA packet has been acknowledged.
local function tget(gett)
    socket.try(gett.host, "missing host")
    local con = socket.try(socket.udp())
    -- every later failure closes the socket before the error propagates
    local try = socket.newtry(function() con:close() end)
    -- convert from name to ip if needed (kept local: gett is not modified)
    local host = try(socket.dns.toip(gett.host))
    local port = gett.port or 69
    -- Without a timeout, receive would block forever and the retry loops
    -- below could never observe "timeout".
    con:settimeout(1)
    local path = string.gsub(gett.path or "", "^/", "")
    path = url.unescape(path)
    -- The first reply's source host/port is used for the rest of the
    -- transfer. The request is sent at most 6 times while replies time out.
    local dgram, datahost, dataport
    local retries = 0
    repeat
        try(con:sendto(RRQ(path, "octet"), host, port))
        dgram, datahost, dataport = con:receivefrom()
        retries = retries + 1
    until dgram or datahost ~= "timeout" or retries > 5
    -- on failure receivefrom returned nil plus an error message
    try(dgram, datahost)
    -- associate socket with data host/port
    try(con:setpeername(datahost, dataport))
    local sink = gett.sink or ltn12.sink.null()
    -- number of the last block handed to the sink
    local last = 0
    -- process all data packets
    while true do
        local code = get_OP(dgram)
        -- only decode the error text when the packet really is an ERROR
        if code == OP_ERROR then try(nil, get_ERROR(dgram)) end
        try(code == OP_DATA, "unhandled opcode " .. code)
        local block, data = split_DATA(dgram)
        -- if not repeated, write (duplicates are only re-acknowledged)
        if block == last + 1 then
            try(sink(data))
            last = block
        end
        -- last packet brings less than 512 bytes of data
        if string.len(data) < 512 then
            try(con:send(ACK(block)))
            con:close()
            try(sink(nil))
            return 1
        end
        -- acknowledge and get the next packet, retrying on timeout
        local err
        retries = 0
        repeat
            try(con:send(ACK(last)))
            dgram, err = con:receive()
            retries = retries + 1
        until dgram or err ~= "timeout" or retries > 5
        try(dgram, err)
    end
end
--- Parse a "tftp://" URL and check that it names a host.
-- Errors are raised through socket.try.
-- @param u URL string.
-- @return the component table produced by url.parse.
local function parse(u)
    -- NOTE(review): `default` (URL defaults merged by url.parse) is not
    -- defined in this excerpt; presumably a module-level table. Confirm.
    local t = socket.try(url.parse(u, default))
    -- Guard the concatenation: a nil scheme would otherwise raise a plain
    -- Lua error instead of the intended "invalid scheme" message.
    socket.try(t.scheme == "tftp",
        "invalid scheme '" .. (t.scheme or "") .. "'")
    socket.try(t.host, "invalid host")
    return t
end
--- Download a TFTP URL into memory.
-- @param u URL string such as "tftp://host/dir/file".
-- @return the complete file contents as a single string.
local function sget(u)
    local gett = parse(u)
    -- collect the chunks tget feeds to the sink, then join them once
    local chunks = {}
    gett.sink = ltn12.sink.table(chunks)
    tget(gett)
    return table.concat(chunks)
end
--- Public entry point: download a file over TFTP.
-- Given a URL string, the file contents are returned as a string (sget).
-- Given any other value, it is treated as a request table for tget, and
-- the contents go to that table's sink.
-- Wrapped in socket.protect, so errors raised via socket.try are
-- presumably returned as nil plus a message (see LuaSocket's
-- socket.protect).
get = socket.protect(function(gett)
    local is_url = base.type(gett) == "string"
    if not is_url then
        return tget(gett)
    end
    return sget(gett)
end)