Many improvements to lambdas and function calling

This commit is contained in:
John Wiegley 2012-03-08 00:44:27 -06:00
parent ae4ef7a88d
commit 4d01143400
4 changed files with 223 additions and 132 deletions

340
src/op.cc
View file

@ -178,7 +178,10 @@ expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth,
ptr_op_t varname = sym->kind == O_CONS ? sym->left() : sym;
if (! varname->is_ident()) {
throw_(calc_error, _("Invalid function or lambda parameter"));
std::ostringstream buf;
varname->dump(buf, 0);
throw_(calc_error,
_("Invalid function or lambda parameter: %1") << buf.str());
} else {
DEBUG("expr.compile",
"Defining function parameter " << varname->as_ident());
@ -224,6 +227,23 @@ expr_t::ptr_op_t expr_t::op_t::compile(scope_t& scope, const int depth,
return result;
}
namespace {
expr_t::ptr_op_t lookup_ident(expr_t::ptr_op_t op, scope_t& scope)
{
expr_t::ptr_op_t def = op->left();
// If no definition was pre-compiled for this identifier, look it up
// in the current scope.
if (! def || def->kind == expr_t::op_t::PLUG) {
DEBUG("scope.symbols", "Looking for IDENT '" << op->as_ident() << "'");
def = scope.lookup(symbol_t::FUNCTION, op->as_ident());
}
if (! def)
throw_(calc_error, _("Unknown identifier '%1'") << op->as_ident());
return def;
}
}
value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth)
{
try {
@ -248,23 +268,14 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth)
result = NULL_VALUE;
break;
case IDENT: {
ptr_op_t definition = left();
// If no definition was pre-compiled for this identifier, look it up
// in the current scope.
if (! definition || definition->kind == PLUG) {
DEBUG("scope.symbols", "Looking for IDENT '" << as_ident() << "'");
definition = scope.lookup(symbol_t::FUNCTION, as_ident());
case IDENT:
if (ptr_op_t definition = lookup_ident(this, scope)) {
// Evaluating an identifier is the same as calling its definition
// directly
result = definition->calc(scope, locus, depth + 1);
check_type_context(scope, result);
}
if (! definition)
throw_(calc_error, _("Unknown identifier '%1'") << as_ident());
// Evaluating an identifier is the same as calling its definition
// directly
result = definition->calc(scope, locus, depth + 1);
check_type_context(scope, result);
break;
}
case FUNCTION: {
// Evaluating a FUNCTION is the same as calling it directly; this
@ -302,81 +313,14 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth)
break;
}
case O_CALL: {
ptr_op_t func = left();
string name;
if (func->is_ident()) {
name = func->as_ident();
func = func->left();
if (! func)
func = scope.lookup(symbol_t::FUNCTION, name);
if (! func)
throw_(calc_error, _("Calling unknown function '%1'") << name);
} else {
name = "<lambda>";
}
call_scope_t call_args(scope, locus, depth + 1);
if (has_right())
call_args.set_args(split_cons_expr(right()));
try {
if (func->is_function())
result = func->as_function()(call_args);
else
result = func->calc(call_args, locus, depth + 1);
}
catch (const std::exception&) {
add_error_context(_("While calling function '%1':" << name));
throw;
}
case O_CALL:
result = calc_call(scope, locus, depth);
check_type_context(scope, result);
break;
}
case O_LAMBDA: {
call_scope_t& call_args(find_scope<call_scope_t>(scope, true));
std::size_t args_count(call_args.size());
std::size_t args_index(0);
symbol_scope_t args_scope(*scope_t::empty_scope);
for (ptr_op_t sym = left();
sym;
sym = sym->has_right() ? sym->right() : NULL) {
ptr_op_t varname = sym->kind == O_CONS ? sym->left() : sym;
if (! varname->is_ident()) {
throw_(calc_error, _("Invalid function definition"));
}
else if (args_index == args_count) {
DEBUG("expr.compile", "Defining function argument as null: "
<< varname->as_ident());
args_scope.define(symbol_t::FUNCTION, varname->as_ident(),
wrap_value(NULL_VALUE));
}
else {
DEBUG("expr.compile", "Defining function argument from call_args: "
<< varname->as_ident());
args_scope.define(symbol_t::FUNCTION, varname->as_ident(),
wrap_value(call_args[args_index++]));
}
}
if (args_index < args_count)
throw_(calc_error,
_("Too few arguments in function call (saw %1, wanted %2)")
<< args_count << args_index);
if (right()->is_scope()) {
bind_scope_t outer_scope(scope, *right()->as_scope());
bind_scope_t bound_scope(outer_scope, args_scope);
result = right()->left()->calc(bound_scope, locus, depth + 1);
} else {
result = right()->calc(args_scope, locus, depth + 1);
}
case O_LAMBDA:
result = expr_value(this);
break;
}
case O_MATCH:
result = (right()->calc(scope, locus, depth + 1).as_mask()
@ -457,51 +401,12 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth)
break;
case O_CONS:
result = left()->calc(scope, locus, depth + 1);
if (has_right()) {
value_t temp;
temp.push_back(result);
ptr_op_t next = right();
while (next) {
ptr_op_t value_op;
if (next->kind == O_CONS) {
value_op = next->left();
next = next->has_right() ? next->right() : NULL;
} else {
value_op = next;
next = NULL;
}
temp.push_back(value_op->calc(scope, locus, depth + 1));
}
result = temp;
}
result = calc_cons(scope, locus, depth);
break;
case O_SEQ: {
// An O_SEQ is very similar to an O_CONS except that only the last
// result value in the series is kept. O_CONS builds up a list.
//
// Another feature of O_SEQ is that it pushes a new symbol scope
// onto the stack. We evaluate the left side here to catch any
// side-effects, such as definitions in the case of 'x = 1; x'.
result = left()->calc(scope, locus, depth + 1);
if (has_right()) {
ptr_op_t next = right();
while (next) {
ptr_op_t value_op;
if (next->kind == O_SEQ) {
value_op = next->left();
next = next->right();
} else {
value_op = next;
next = NULL;
}
result = value_op->calc(scope, locus, depth + 1);
}
}
case O_SEQ:
result = calc_seq(scope, locus, depth);
break;
}
default:
throw_(calc_error, _("Unexpected expr node '%1'") << op_context(this));
@ -527,6 +432,183 @@ value_t expr_t::op_t::calc(scope_t& scope, ptr_op_t * locus, const int depth)
}
}
namespace {
expr_t::ptr_op_t find_definition(expr_t::ptr_op_t op, scope_t& scope,
expr_t::ptr_op_t * locus, const int depth,
int recursion_depth = 0)
{
// If the object we are apply call notation to is a FUNCTION value
// or a O_LAMBDA expression, then this is the object we want to
// call.
if (op->is_function() || op->kind == expr_t::op_t::O_LAMBDA)
return op;
if (recursion_depth > 256)
throw_(value_error, _("Function recursion_depth too deep (> 256)"));
// If it's an identifier, look up its definition and see if it's a
// function.
if (op->is_ident())
return find_definition(lookup_ident(op, scope), scope,
locus, depth, recursion_depth + 1);
// Value objects might be callable if they contain an expression.
if (op->is_value()) {
value_t def(op->as_value());
if (is_expr(def))
return find_definition(as_expr(def), scope, locus, depth,
recursion_depth + 1);
else
throw_(value_error, _("Cannot call %1 as a function") << def.label());
}
// Resolve ordinary expressions.
return find_definition(expr_t::op_t::wrap_value(op->calc(scope, locus,
depth + 1)),
scope, locus, depth + 1, recursion_depth + 1);
}
value_t call_lambda(expr_t::ptr_op_t func, scope_t& scope,
call_scope_t& call_args, expr_t::ptr_op_t * locus,
const int depth)
{
std::size_t args_index(0);
std::size_t args_count(call_args.size());
symbol_scope_t args_scope(*scope_t::empty_scope);
for (expr_t::ptr_op_t sym = func->left();
sym;
sym = sym->has_right() ? sym->right() : NULL) {
expr_t::ptr_op_t varname =
sym->kind == expr_t::op_t::O_CONS ? sym->left() : sym;
if (! varname->is_ident()) {
throw_(calc_error, _("Invalid function definition"));
}
else if (args_index == args_count) {
DEBUG("expr.calc", "Defining function argument as null: "
<< varname->as_ident());
args_scope.define(symbol_t::FUNCTION, varname->as_ident(),
expr_t::op_t::wrap_value(NULL_VALUE));
}
else {
DEBUG("expr.calc", "Defining function argument from call_args: "
<< varname->as_ident());
args_scope.define(symbol_t::FUNCTION, varname->as_ident(),
expr_t::op_t::wrap_value(call_args[args_index++]));
}
}
if (args_index < args_count)
throw_(calc_error,
_("Too few arguments in function call (saw %1, wanted %2)")
<< args_count << args_index);
if (func->right()->is_scope()) {
bind_scope_t outer_scope(scope, *func->right()->as_scope());
bind_scope_t bound_scope(outer_scope, args_scope);
return func->right()->left()->calc(bound_scope, locus, depth + 1);
} else {
return func->right()->calc(args_scope, locus, depth + 1);
}
}
}
value_t expr_t::op_t::call(const value_t& args, scope_t& scope,
ptr_op_t * locus, const int depth)
{
call_scope_t call_args(scope, locus, depth + 1);
call_args.set_args(args);
if (is_function())
return as_function()(call_args);
else if (kind == O_LAMBDA)
return call_lambda(this, scope, call_args, locus, depth);
else
return find_definition(this, scope, locus, depth)
->calc(call_args, locus, depth);
}
value_t expr_t::op_t::calc_call(scope_t& scope, ptr_op_t * locus,
const int depth)
{
ptr_op_t func = left();
string name = func->is_ident() ? func->as_ident() : "<value expr>";
try {
func = find_definition(func, scope, locus, depth);
call_scope_t call_args(scope, locus, depth + 1);
if (has_right())
call_args.set_args(split_cons_expr(right()));
if (func->is_function()) {
return func->as_function()(call_args);
} else {
assert(func->kind == O_LAMBDA);
return call_lambda(func, scope, call_args, locus, depth);
}
}
catch (const std::exception&) {
add_error_context(_("While calling function '%1':" << name));
throw;
}
}
value_t expr_t::op_t::calc_cons(scope_t& scope, ptr_op_t * locus,
const int depth)
{
value_t result = left()->calc(scope, locus, depth + 1);
if (has_right()) {
value_t temp;
temp.push_back(result);
ptr_op_t next = right();
while (next) {
ptr_op_t value_op;
if (next->kind == O_CONS) {
value_op = next->left();
next = next->has_right() ? next->right() : NULL;
} else {
value_op = next;
next = NULL;
}
temp.push_back(value_op->calc(scope, locus, depth + 1));
}
result = temp;
}
return result;
}
value_t expr_t::op_t::calc_seq(scope_t& scope, ptr_op_t * locus,
const int depth)
{
// An O_SEQ is very similar to an O_CONS except that only the last
// result value in the series is kept. O_CONS builds up a list.
//
// Another feature of O_SEQ is that it pushes a new symbol scope onto
// the stack. We evaluate the left side here to catch any
// side-effects, such as definitions in the case of 'x = 1; x'.
value_t result = left()->calc(scope, locus, depth + 1);
if (has_right()) {
ptr_op_t next = right();
while (next) {
ptr_op_t value_op;
if (next->kind == O_SEQ) {
value_op = next->left();
next = next->right();
} else {
value_op = next;
next = NULL;
}
result = value_op->calc(scope, locus, depth + 1);
}
}
return result;
}
namespace {
bool print_cons(std::ostream& out, const expr_t::const_ptr_op_t op,
const expr_t::op_t::context_t& context)

View file

@ -280,6 +280,9 @@ public:
value_t calc(scope_t& scope, ptr_op_t * locus = NULL,
const int depth = 0);
value_t call(const value_t& args, scope_t& scope,
ptr_op_t * locus = NULL, const int depth = 0);
struct context_t
{
ptr_op_t expr_op;
@ -307,6 +310,11 @@ public:
static ptr_op_t wrap_functor(expr_t::func_t fobj);
static ptr_op_t wrap_scope(shared_ptr<scope_t> sobj);
private:
value_t calc_call(scope_t& scope, ptr_op_t * locus, const int depth);
value_t calc_cons(scope_t& scope, ptr_op_t * locus, const int depth);
value_t calc_seq(scope_t& scope, ptr_op_t * locus, const int depth);
#if defined(HAVE_BOOST_SERIALIZATION)
private:
/** Serialization. */

View file

@ -87,9 +87,6 @@ expr_t::parser_t::parse_call_expr(std::istream& in,
node->set_left(prev);
push_token(tok); // let the parser see the '(' again
node->set_right(parse_value_expr(in, tflags.plus_flags(PARSE_SINGLE)));
if (! node->right())
throw_(parse_error,
_("%1 operator not followed by argument") << tok.symbol);
} else {
push_token(tok);
break;

View file

@ -15,3 +15,7 @@ test eval 'foo = x, y, z -> print(x, y, z); foo(1, 2, 3)'
123
1
end test
test eval 'foo(x,y)=y(1, 2, 3);foo(amount_expr, (s,d,t -> t))'
3
end test