Skip to content

Commit

Permalink
Nullary tagged union constructors no longer need parens
Browse files Browse the repository at this point in the history
For example, now we can write types.T.Any instead of types.T.Any().
It should have been like this from the beginning, but we couldn't
because some of our code was using ir.Cmd objects as table keys.
(This was just fixed in a recent PR)
  • Loading branch information
hugomg committed Jul 23, 2024
1 parent a3fbfdd commit 76a1e1a
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 125 deletions.
84 changes: 42 additions & 42 deletions spec/types_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,41 @@ local types = require "pallene.types"
describe("Pallene types", function()

it("pretty-prints types", function()
assert.same("{ integer }", types.tostring(types.T.Array(types.T.Integer())))
assert.same("{ integer }", types.tostring(types.T.Array(types.T.Integer)))
assert.same("{ x: float, y: float }", types.tostring(
types.T.Table({x = types.T.Float(), y = types.T.Float()})))
types.T.Table({x = types.T.Float, y = types.T.Float})))
end)

it("is_gc works", function()
assert.falsy(types.is_gc(types.T.Integer()))
assert.truthy(types.is_gc(types.T.String()))
assert.truthy(types.is_gc(types.T.Array(types.T.Integer())))
assert.truthy(types.is_gc(types.T.Table({x = types.T.Float()})))
assert.falsy(types.is_gc(types.T.Integer))
assert.truthy(types.is_gc(types.T.String))
assert.truthy(types.is_gc(types.T.Array(types.T.Integer)))
assert.truthy(types.is_gc(types.T.Table({x = types.T.Float})))
assert.truthy(types.is_gc(types.T.Function({}, {})))
end)

describe("equality", function()

it("works for primitive types", function()
assert.truthy(types.equals(types.T.Integer(), types.T.Integer()))
assert.falsy(types.equals(types.T.Integer(), types.T.String()))
assert.truthy(types.equals(types.T.Integer, types.T.Integer))
assert.falsy(types.equals(types.T.Integer, types.T.String))
end)

it("is true for two identical tables", function()
local t1 = types.T.Table({
y = types.T.Integer(), x = types.T.Integer()})
y = types.T.Integer, x = types.T.Integer})
local t2 = types.T.Table({
x = types.T.Integer(), y = types.T.Integer()})
x = types.T.Integer, y = types.T.Integer})
assert.truthy(types.equals(t1, t2))
assert.truthy(types.equals(t2, t1))
end)

it("is false for tables with different number of fields", function()
local t1 = types.T.Table({x = types.T.Integer()})
local t2 = types.T.Table({x = types.T.Integer(),
y = types.T.Integer()})
local t3 = types.T.Table({x = types.T.Integer(),
y = types.T.Integer(), z = types.T.Integer()})
local t1 = types.T.Table({x = types.T.Integer})
local t2 = types.T.Table({x = types.T.Integer,
y = types.T.Integer})
local t3 = types.T.Table({x = types.T.Integer,
y = types.T.Integer, z = types.T.Integer})
assert.falsy(types.equals(t1, t2))
assert.falsy(types.equals(t2, t1))
assert.falsy(types.equals(t2, t3))
Expand All @@ -52,39 +52,39 @@ describe("Pallene types", function()
end)

it("is false for tables with different field names", function()
local t1 = types.T.Table({x = types.T.Integer()})
local t2 = types.T.Table({y = types.T.Integer()})
local t1 = types.T.Table({x = types.T.Integer})
local t2 = types.T.Table({y = types.T.Integer})
assert.falsy(types.equals(t1, t2))
assert.falsy(types.equals(t2, t1))
end)

it("is false for tables with different field types", function()
local t1 = types.T.Table({x = types.T.Integer()})
local t2 = types.T.Table({x = types.T.Float()})
local t1 = types.T.Table({x = types.T.Integer})
local t2 = types.T.Table({x = types.T.Float})
assert.falsy(types.equals(t1, t2))
assert.falsy(types.equals(t2, t1))
end)

it("is true for identical functions", function()
local f1 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()})
local f2 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()})
local f1 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean})
local f2 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean})
assert.truthy(types.equals(f1, f2))
end)

it("is false for functions with different input types", function()
local f1 = types.T.Function({types.T.String(), types.T.Boolean()}, {types.T.Boolean()})
local f2 = types.T.Function({types.T.Integer(), types.T.Integer()}, {types.T.Boolean()})
local f1 = types.T.Function({types.T.String, types.T.Boolean}, {types.T.Boolean})
local f2 = types.T.Function({types.T.Integer, types.T.Integer}, {types.T.Boolean})
assert.falsy(types.equals(f1, f2))
end)

it("is false for functions with different output types", function()
local f1 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()})
local f2 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Integer()})
local f1 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean})
local f2 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Integer})
assert.falsy(types.equals(f1, f2))
end)

it("is false for functions with different input arity", function()
local s = types.T.String()
local s = types.T.String
local f1 = types.T.Function({}, {s})
local f2 = types.T.Function({s}, {s})
local f3 = types.T.Function({s, s}, {s})
Expand All @@ -97,7 +97,7 @@ describe("Pallene types", function()
end)

it("is false for functions with different output arity", function()
local s = types.T.String()
local s = types.T.String
local f1 = types.T.Function({s}, {})
local f2 = types.T.Function({s}, {s})
local f3 = types.T.Function({s}, {s, s})
Expand All @@ -123,42 +123,42 @@ describe("Pallene types", function()

describe("consistency", function()
it("allows 'any' on either side", function()
assert.truthy(types.consistent(types.T.Any(), types.T.Any()))
assert.truthy(types.consistent(types.T.Any(), types.T.Integer()))
assert.truthy(types.consistent(types.T.Integer(), types.T.Any()))
assert.truthy(types.consistent(types.T.Any, types.T.Any))
assert.truthy(types.consistent(types.T.Any, types.T.Integer))
assert.truthy(types.consistent(types.T.Integer, types.T.Any))
end)

it("allows types with same tag", function()
assert.truthy(types.consistent(
types.T.Integer(),
types.T.Integer()
types.T.Integer,
types.T.Integer
))

assert.truthy(types.consistent(
types.T.Array(types.T.Integer()),
types.T.Array(types.T.Integer())
types.T.Array(types.T.Integer),
types.T.Array(types.T.Integer)
))

assert.truthy(types.consistent(
types.T.Array(types.T.Integer()),
types.T.Array(types.T.String())
types.T.Array(types.T.Integer),
types.T.Array(types.T.String)
))

assert.truthy(types.consistent(
types.T.Function({types.T.Integer()}, {types.T.Integer()}),
types.T.Function({types.T.String(), types.T.String()}, {})
types.T.Function({types.T.Integer}, {types.T.Integer}),
types.T.Function({types.T.String, types.T.String}, {})
))
end)

it("forbids different tags", function()
assert.falsy(types.consistent(
types.T.Integer(),
types.T.String()
types.T.Integer,
types.T.String
))

assert.falsy(types.consistent(
types.T.Array(types.T.Integer()),
types.T.Function({types.T.Integer()},{types.T.Integer()})
types.T.Array(types.T.Integer),
types.T.Function({types.T.Integer},{types.T.Integer})
))
end)
end)
Expand Down
42 changes: 21 additions & 21 deletions src/pallene/builtins.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,38 @@ local builtins = {}

-- TODO: It will be easier to read this is we could write down the types using the normal grammar

local ipairs_itertype = T.Function({T.Any(), T.Any()}, {T.Any(), T.Any()})
local ipairs_itertype = T.Function({T.Any, T.Any}, {T.Any, T.Any})

builtins.functions = {
type = T.Function({ T.Any() }, { T.String() }),
tostring = T.Function({ T.Any() }, { T.String() }),
ipairs = T.Function({T.Array(T.Any())}, {ipairs_itertype, T.Any(), T.Any()})
type = T.Function({ T.Any }, { T.String }),
tostring = T.Function({ T.Any }, { T.String }),
ipairs = T.Function({T.Array(T.Any)}, {ipairs_itertype, T.Any, T.Any})
}

builtins.modules = {
io = {
write = T.Function({ T.String() }, {}),
write = T.Function({ T.String }, {}),
},
math = {
abs = T.Function({ T.Float() }, { T.Float() }),
ceil = T.Function({ T.Float() }, { T.Integer() }),
floor = T.Function({ T.Float() }, { T.Integer() }),
fmod = T.Function({ T.Float(), T.Float() }, { T.Float() }),
exp = T.Function({ T.Float() }, { T.Float() }),
ln = T.Function({ T.Float() }, { T.Float() }),
log = T.Function({ T.Float(), T.Float() }, { T.Float() }),
modf = T.Function({ T.Float() }, { T.Integer(), T.Float() }),
pow = T.Function({ T.Float(), T.Float() }, { T.Float() }),
sqrt = T.Function({ T.Float() }, { T.Float() }),
abs = T.Function({ T.Float }, { T.Float }),
ceil = T.Function({ T.Float }, { T.Integer }),
floor = T.Function({ T.Float }, { T.Integer }),
fmod = T.Function({ T.Float, T.Float }, { T.Float }),
exp = T.Function({ T.Float }, { T.Float }),
ln = T.Function({ T.Float }, { T.Float }),
log = T.Function({ T.Float, T.Float }, { T.Float }),
modf = T.Function({ T.Float }, { T.Integer, T.Float }),
pow = T.Function({ T.Float, T.Float }, { T.Float }),
sqrt = T.Function({ T.Float }, { T.Float }),
-- constant numbers
huge = T.Float(),
mininteger = T.Integer(),
maxinteger = T.Integer(),
pi = T.Float(),
huge = T.Float,
mininteger = T.Integer,
maxinteger = T.Integer,
pi = T.Float,
},
string = {
char = T.Function({ T.Integer() }, { T.String() }),
sub = T.Function({ T.String(), T.Integer(), T.Integer() }, { T.String() }),
char = T.Function({ T.Integer }, { T.String }),
sub = T.Function({ T.String, T.Integer, T.Integer }, { T.String }),
},
}

Expand Down
8 changes: 4 additions & 4 deletions src/pallene/coder.lua
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ function Coder:c_value(value)
return C.float(value.value)
elseif tag == "ir.Value.String" then
local str = value.value
return lua_value(types.T.String(), self:string_upvalue_slot(str))
return lua_value(types.T.String, self:string_upvalue_slot(str))
elseif tag == "ir.Value.LocalVar" then
return self:c_var(value.id)
elseif tag == "ir.Value.Upvalue" then
Expand Down Expand Up @@ -730,8 +730,8 @@ function Coder:init_upvalues()

-- If we are using tracebacks
if self.flags.use_traceback then
table.insert(self.constants, coder.Constant.DebugUserdata())
table.insert(self.constants, coder.Constant.DebugMetatable())
table.insert(self.constants, coder.Constant.DebugUserdata)
table.insert(self.constants, coder.Constant.DebugMetatable)
end

-- Metatables
Expand Down Expand Up @@ -1353,7 +1353,7 @@ gen_cmd["SetTable"] = function(self, args)
tab = tab,
key = key,
val = val,
init_keyv = set_stack_slot(types.T.String(), "&keyv", key),
init_keyv = set_stack_slot(types.T.String, "&keyv", key),
init_valv = set_stack_slot(src_typ, "&valv", val),
-- Here we use set_stack_slot slot on a heap object, because
-- we call the barrier by hand outside the if statement.
Expand Down
30 changes: 18 additions & 12 deletions src/pallene/tagged_union.lua
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,32 @@ local function make_tag(mod_name, type_name, cons_name)
end

-- Create a tagged union constructor
-- @param module Module table where the type is being defined
-- @param mod_table Module table where the type is being defined
-- @param mod_name Name of the module
-- @param type_name Name of the type
-- @param constructors Name of the constructor => fields of the record
local function define_union(mod_table, mod_name, type_name, constructors)
mod_table[type_name] = {}
for cons_name, fields in pairs(constructors) do
local tag = make_tag(mod_name, type_name, cons_name)
local function cons(...)
local args = table.pack(...)
if args.n ~= #fields then
error(string.format(
"wrong number of arguments for %s. Expected %d but received %d.",
cons_name, #fields, args.n))
end
local node = { _tag = tag }
for i, field in ipairs(fields) do
node[field] = args[i]

local cons
if #fields == 0 then
cons = { _tag = tag }
else
cons = function(...)
local args = table.pack(...)
if args.n ~= #fields then
error(string.format(
"wrong number of arguments for %s. Expected %d but received %d.",
cons_name, #fields, args.n))
end
local node = { _tag = tag }
for i, field in ipairs(fields) do
node[field] = args[i]
end
return node
end
return node
end
mod_table[type_name][cons_name] = cons
end
Expand Down
Loading

0 comments on commit 76a1e1a

Please sign in to comment.