aboutsummaryrefslogtreecommitdiff
path: root/src/ast_opts.lua
blob: de4993292c0674fc4863b0ce1ff3ef6a6e3a589c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
--[[
	Optimizations for abstract syntax trees
]]
--Output sink used by printtable below; io.write does not append a newline.
local msg = io.write
--A debugging function, a replacement for glua PrintTable.
--Writes every key/value pair of `tbl` through `msg`, one pair per line,
--indented with tabset + 1 tabs; nested tables are printed recursively one
--level deeper. Keys go through tostring so boolean/table keys do not error.
--A table that is already being printed further up (a cycle) is shown as
--"<cycle>" instead of recursing forever.
-- @param tbl    table to print
-- @param tabset optional indentation depth (default 0)
-- @param path   internal: set of tables on the current print path
local function printtable(tbl, tabset, path)
    tabset = tabset or 0
    path = path or {}
    path[tbl] = true
    for k,v in pairs(tbl) do
        for i = 0,tabset do msg("\t") end
        msg(tostring(k) .. ":")
        if type(v) ~= "table" then
            msg(tostring(v) .. "\n")
        elseif path[v] then
            msg("<cycle>\n")
        else
            msg("\n")
            printtable(v, tabset + 1, path)
        end
    end
    --Only tables on the current path count as cycles; shared (DAG) subtables
    --are printed every time they appear.
    path[tbl] = nil
end

--A function to see if two ast's are equal-ish (does not compare position).
--Tables are compared structurally: both must have the same set of keys and
--equal values under those keys, recursively. The "pos" key is ignored at
--every level. Non-table values are compared with ==.
-- @param tbl1 first value (usually an AST node)
-- @param tbl2 second value
-- @return true if the two values are structurally equal, false otherwise
local function deepcompare(tbl1, tbl2)
	if type(tbl1) ~= type(tbl2) then return false end
	--Leaves (strings, numbers, booleans, ...) are compared directly; calling
	--pairs() on them would raise an error.
	if type(tbl1) ~= "table" then return tbl1 == tbl2 end
	for k, v in pairs(tbl1) do
		if k ~= "pos" and not deepcompare(v, tbl2[k]) then
			return false
		end
	end
	--Keys that only tbl2 has also make the tables differ (keeps it symmetric).
	for k in pairs(tbl2) do
		if k ~= "pos" and tbl1[k] == nil then
			return false
		end
	end
	return true
end

local opts = {}

--Optimization 1
--Constant-folds binary operators whose operands are both number literals,
--e.g. {tag="Op", "add", {tag="Number", 2}, {tag="Number", 3}} becomes
--{tag="Number", 5}. Only the operators listed in `foldables` are folded.
--Division is deliberately excluded because its result can stringify much
--longer than the source expression; the other folds are not length-checked.
local foldables = {
	["add"] = function(a,b) return a + b end,
	["mul"] = function(a,b) return a * b end,
	["mod"] = function(a,b) return a % b end,
	["sub"] = function(a,b) return a - b end,
	--["div"] = function(a,b) return a / b end, division has the chance to give us really long strings!
}

--True when `n` can be written back out as a plain numeric literal:
--NaN fails n == n, and +/-inf would stringify as "inf"/"-inf".
local function is_finite(n)
	return n == n and n ~= math.huge and n ~= -math.huge
end

--Folds the node in place. Returns true if it was rewritten, false otherwise.
opts[1] = function(ast)
	if ast.tag ~= "Op" then return false end
	local opname = ast[1]
	local func = foldables[opname]
	local lhs, rhs = ast[2], ast[3]
	--rhs == nil means a unary operator, which is never folded here.
	if func == nil or rhs == nil or lhs.tag ~= "Number" or rhs.tag ~= "Number" then
		return false
	end
	local a, b = lhs[1], rhs[1]
	--x % 0 raises an error for integers (Lua 5.3+) and yields NaN for floats;
	--either way there is no valid literal to fold to, so leave it for runtime.
	if opname == "mod" and b == 0 then return false end
	local result = func(a, b)
	if not is_finite(result) then return false end
	ast.tag = "Number"
	ast[1] = result
	for i = #ast, 2, -1 do
		ast[i] = nil
	end
	return true
end

--Optimization 2
--Find places where we can replace calls with invokes:
--  obj.name(obj, a, b)  ->  obj:name(a, b)
--In the AST layout this file works with:
--  Call{ Index{ obj, String{name} }, obj, a, b }
--  Invoke{ obj, String{name}, a, b }
--Both nodes have the same number of children, so the arguments after the
--self argument already sit at the right indices and need no shifting.

--Lua reserved words; none of them may follow ':' in an invoke.
local lua_keywords = {
	["and"] = true, ["break"] = true, ["do"] = true, ["else"] = true,
	["elseif"] = true, ["end"] = true, ["false"] = true, ["for"] = true,
	["function"] = true, ["goto"] = true, ["if"] = true, ["in"] = true,
	["local"] = true, ["nil"] = true, ["not"] = true, ["or"] = true,
	["repeat"] = true, ["return"] = true, ["then"] = true, ["true"] = true,
	["until"] = true, ["while"] = true,
}

--True if `name` can be written as the method part of `obj:name(...)`.
local function is_method_name(name)
	return type(name) == "string"
		and name:match("^[%a_][%w_]*$") ~= nil
		and not lua_keywords[name]
end

--True for object expressions that can be evaluated once instead of twice
--without changing meaning: variables and constant-key index chains over
--them (a, a.b, a[1]). Anything else (calls, parenthesised expressions, ...)
--might have side effects, so the rewrite is skipped for it.
--NOTE(review): this ignores __index metamethods with side effects.
local function is_simple_object(expr)
	if type(expr) ~= "table" then return false end
	if expr.tag == "Id" then return true end
	if expr.tag == "Index" then
		local key = expr[2]
		return type(key) == "table"
			and (key.tag == "String" or key.tag == "Number")
			and is_simple_object(expr[1])
	end
	return false
end

--Rewrites the node in place. Returns true if it was turned into an Invoke,
--false otherwise.
opts[2] = function(ast)
	if ast.tag ~= "Call" then return false end
	local callee, selfarg = ast[1], ast[2]
	--Needs an obj.name callee and at least one argument to act as self.
	if type(callee) ~= "table" or callee.tag ~= "Index" or selfarg == nil then
		return false
	end
	local obj, key = callee[1], callee[2]
	--obj[k](obj) with a non-constant or non-identifier key has no ':' form.
	if type(key) ~= "table" or key.tag ~= "String" or not is_method_name(key[1]) then
		return false
	end
	--Compare whole nodes (not just their first child) so that e.g.
	--x.b("x") is not mistaken for x.b(x).
	if not is_simple_object(obj) or not deepcompare(obj, selfarg) then
		return false
	end
	ast.tag = "Invoke"
	ast[1] = obj
	ast[2] = key
	return true
end

--The module's value: the table of optimization passes, indexed by number.
--Each pass takes an AST node, may modify it in place, and returns true when it rewrote the node.
return opts