aboutsummaryrefslogtreecommitdiff
path: root/src/ast_opts.lua
blob: 060c0ac764e1a9821ad13c692448b529fb0da970 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
--[[
	Optimizations for abstract syntax trees
]]
--Shorthand for io.write: writes its arguments to stdout with no separators
--and no trailing newline (unlike print). printtable emits all output through it.
local msg = io.write

--A debugging function, a replacement for glua PrintTable.
--Writes each key/value pair of tbl through msg, one per line, as "key:value".
--A table value prints as "key:" on its own line, followed by its contents one
--tab deeper.
--tabset: nesting depth (default 0); every line gets tabset + 1 leading tabs.
--ancestors: internal set of tables on the current print path. A table that
--contains itself (directly or indirectly) prints "<cycle>" instead of
--recursing forever. Callers never need to pass it.
local function printtable(tbl, tabset, ancestors)
	tabset = tabset or 0
	ancestors = ancestors or {}
	ancestors[tbl] = true
	local indent = string.rep("\t", tabset + 1)
	for k, v in pairs(tbl) do
		--tostring: a plain `k .. ":"` raises on boolean or table keys
		msg(indent .. tostring(k) .. ":")
		if type(v) ~= "table" then
			msg(tostring(v) .. "\n")
		elseif ancestors[v] then
			msg("<cycle>\n")
		else
			msg("\n")
			printtable(v, tabset + 1, ancestors)
		end
	end
	ancestors[tbl] = nil
end

--A function to see if two values are structurally equal-ish.
--Tables are compared key by key, recursively, in both directions, so a key
--present on only one side makes them unequal. Keys named "pos" are ignored
--at every level (presumably parser source positions — TODO confirm); that
--is what makes the comparison "-ish". Non-table values are compared with ==.
--Produces no output (earlier versions printed debug traces).
local function deepcompare(tbl1, tbl2)
	if type(tbl1) ~= type(tbl2) then return false end
	if type(tbl1) ~= "table" then return tbl1 == tbl2 end
	for k, v in pairs(tbl1) do
		if k ~= "pos" and not deepcompare(v, tbl2[k]) then
			return false
		end
	end
	--Keys present only in tbl2 were not visited above
	for k in pairs(tbl2) do
		if k ~= "pos" and tbl1[k] == nil then
			return false
		end
	end
	return true
end

--Checks whether two AST expressions structurally refer to the same index.
--Only plain identifiers (Id), string keys (String), and Index chains built
--from them are understood, and both sides must carry the same tag.
--Anything else, including non-table arguments, yields false.
local function indexcompare(lhs, rhs)
	if type(lhs) ~= "table" or type(rhs) ~= "table" then
		return false
	end
	local tag = lhs.tag
	if rhs.tag ~= tag then
		return false
	end
	if tag == "Index" then
		--Same container and same key, checked recursively
		return indexcompare(lhs[1], rhs[1]) and indexcompare(lhs[2], rhs[2])
	elseif tag == "Id" or tag == "String" then
		--Leaf nodes: the stored name / string contents must match
		return lhs[1] == rhs[1]
	end
	return false
end

local opts = {}

--Optimization 1
--Constant folding: when a binary operator has number literals on both sides,
--replace the whole Op node in place with a single Number node holding the result.
--NOTE(review): the intent was to fold only when the result prints shorter
--than the original expression; no such length check is implemented.
--Each foldable takes the two operand values and returns the folded value.
local foldables = {
	["add"] = function(a,b) return a + b end,
	["mul"] = function(a,b) return a * b end,
	["mod"] = function(a,b) return a % b end,
	["sub"] = function(a,b) return a - b end,
	--["div"] = function(a,b) return a / b end,-- division has the chance to give us really long strings!
}

--True when n is an ordinary finite number (not NaN, not +/-inf)
local function isfinite(n)
	return n == n and n ~= math.huge and n ~= -math.huge
end

opts[1] = function(ast)
	if ast.tag ~= "Op" then return false end
	local func = foldables[ast[1]]
	--Unary operators have no third child, so only binary ops can fold
	if ast[3] == nil or func == nil then return false end
	local lhs, rhs = ast[2], ast[3]
	if lhs.tag ~= "Number" or rhs.tag ~= "Number" then return false end
	local a, b = lhs[1], rhs[1]
	--x % 0 raises on Lua 5.3+ integers and yields NaN on floats; leave the
	--expression alone so the program keeps its original runtime behaviour
	if ast[1] == "mod" and tonumber(b) == 0 then return false end
	local result = func(a, b)
	--NaN and +/-inf have no numeric-literal spelling, so keep the expression
	if not isfinite(result) then return false end
	ast.tag = "Number"
	ast[1] = result
	for i = 2, #ast do
		ast[i] = nil
	end
	return true
end

--Optimization 2
--Find places where we can replace calls with invokes.
--Lua provides syntax sugar where
--	table.method(table,arg1,arg2,...)
--is the same as
--	table:method(arg1,arg2,...)
--In AST terms (lua-parser style nodes assumed — TODO confirm against the parser):
--	Call{ Index{ obj, String{"method"} }, obj, arg1, arg2, ... }
--becomes
--	Invoke{ obj, String{"method"}, arg1, arg2, ... }
--Only the first two children change; the arguments already sit at index 3+.

--Words that cannot follow ':' even though they are valid string keys.
--"continue" is included because GLua reserves it (harmless for plain Lua).
local reservedwords = {}
for word in ([[
	and break continue do else elseif end false for function goto if in
	local nil not or repeat return then true until while
]]):gmatch("%S+") do
	reservedwords[word] = true
end

--True when s is a string usable as the method name in `obj:s(...)`
local function ismethodname(s)
	return type(s) == "string" and s:match("^[%a_][%w_]*$") ~= nil and not reservedwords[s]
end

opts[2] = function(ast)
	if ast.tag ~= "Call" then return false end
	local callee, selfarg = ast[1], ast[2]
	--Need at least one argument, and the callee must be a field access obj.name
	if selfarg == nil or type(callee) ~= "table" or callee.tag ~= "Index" then
		return false
	end
	local obj, key = callee[1], callee[2]
	if type(key) ~= "table" or key.tag ~= "String" or not ismethodname(key[1]) then
		return false
	end
	--Restrict the receiver to names and field chains: a literal receiver such
	--as ("s"):m() needs parentheses, and indexcompare also accepts String nodes
	if type(obj) ~= "table" or (obj.tag ~= "Id" and obj.tag ~= "Index") then
		return false
	end
	--The first argument must be exactly the object the method was looked up on
	if not indexcompare(obj, selfarg) then return false end
	ast.tag = "Invoke"
	ast[1] = obj
	ast[2] = key
	return true
end

--Optimization 3
--Find places where we have multiple variables declared, but not initialized,
--and turn them into a 1-liner
--so local a local b local c
--becomes
--local a,b,c
--Every run of adjacent bare locals in the block is merged in a single pass,
--including statements that declare several names (local b, c).
--Returns true if at least one statement was merged away.

--True when stmt is a `local` statement with no initializer expressions
local function isbarelocal(stmt)
	return stmt.tag == "Local" and next(stmt[2]) == nil
end

opts[3] = function(ast)
	if ast.tag ~= "Block" then return false end
	local changed = false
	local i = 1
	while i <= #ast do
		local head = ast[i]
		if isbarelocal(head) then
			local names = head[1]
			--table.remove shifts the following statement into slot i + 1,
			--so keep examining that same slot until the run ends
			while ast[i + 1] ~= nil and isbarelocal(ast[i + 1]) do
				for _, name in ipairs(ast[i + 1][1]) do
					names[#names + 1] = name
				end
				table.remove(ast, i + 1)
				changed = true
			end
		end
		i = i + 1
	end
	return changed
end

--The module's value: opts[k] is optimization pass k. Each pass takes one AST
--node, may rewrite it in place, and returns true if it changed something.
return opts