Skip to content

Commit

Permalink
Merge pull request #2094 from anutosh491/GSoC_PR6
Browse files Browse the repository at this point in the history
Added support for symbolic elementary functions
  • Loading branch information
certik authored Jul 4, 2023
2 parents 5334944 + 486db54 commit 39e1a26
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 63 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ RUN(NAME symbolics_02 LABELS cpython_sym c_sym)
RUN(NAME symbolics_03 LABELS cpython_sym c_sym)
RUN(NAME symbolics_04 LABELS cpython_sym c_sym)
RUN(NAME symbolics_05 LABELS cpython_sym c_sym)
RUN(NAME symbolics_06 LABELS cpython_sym c_sym)

RUN(NAME sizeof_01 LABELS llvm c
EXTRAFILES sizeof_01b.c)
Expand Down
37 changes: 37 additions & 0 deletions integration_tests/symbolics_06.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from sympy import Symbol, sin, cos, exp, log, Abs, pi, diff
from lpython import S

def test_elementary_functions():

# test sin, cos
x: S = Symbol('x')
assert(sin(pi) == S(0))
assert(sin(pi/S(2)) == S(1))
assert(sin(S(2)*pi) == S(0))
assert(cos(pi) == S(-1))
assert(cos(pi/S(2)) == S(0))
assert(cos(S(2)*pi) == S(1))
assert(diff(sin(x), x) == cos(x))
assert(diff(cos(x), x) == S(-1)*sin(x))

# test exp, log
assert(exp(S(0)) == S(1))
assert(log(S(1)) == S(0))
assert(diff(exp(x), x) == exp(x))
assert(diff(log(x), x) == S(1)/x)

# test Abs
assert(Abs(S(-10)) == S(10))
assert(Abs(S(10)) == S(10))
assert(Abs(S(-1)*x) == Abs(x))

# test composite functions
a: S = exp(x)
b: S = sin(a)
c: S = cos(b)
d: S = log(c)
e: S = Abs(d)
print(e)
assert(e == Abs(log(cos(sin(exp(x))))))

test_elementary_functions()
71 changes: 48 additions & 23 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -2702,7 +2702,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
out += func_name; break; \
}

std::string performSymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) {
std::string performBinarySymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) {
headers.insert("symengine/cwrapper.h");
std::string indent(4, ' ');
LCOMPILERS_ASSERT(x.n_args == 2);
Expand All @@ -2727,6 +2727,23 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
return target;
}

std::string performUnarySymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) {
headers.insert("symengine/cwrapper.h");
std::string indent(4, ' ');
LCOMPILERS_ASSERT(x.n_args == 1);
std::string target = symengine_queue.push();
std::string target_src = symengine_src;
this->visit_expr(*x.m_args[0]);
std::string arg1 = src;
std::string arg1_src = symengine_src;
if (ASR::is_a<ASR::Var_t>(*x.m_args[0])) {
symengine_queue.pop();
}
symengine_src = target_src + arg1_src;
symengine_src += indent + functionName + "(" + target + ", " + arg1 + ");\n";
return target;
}

void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t &x) {
std::string out;
std::string indent(4, ' ');
Expand All @@ -2745,27 +2762,51 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
SET_INTRINSIC_NAME(Exp2, "exp2");
SET_INTRINSIC_NAME(Expm1, "expm1");
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)): {
src = performSymbolicOperation("basic_add", x);
src = performBinarySymbolicOperation("basic_add", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub)): {
src = performSymbolicOperation("basic_sub", x);
src = performBinarySymbolicOperation("basic_sub", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul)): {
src = performSymbolicOperation("basic_mul", x);
src = performBinarySymbolicOperation("basic_mul", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv)): {
src = performSymbolicOperation("basic_div", x);
src = performBinarySymbolicOperation("basic_div", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPow)): {
src = performSymbolicOperation("basic_pow", x);
src = performBinarySymbolicOperation("basic_pow", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiff)): {
src = performSymbolicOperation("basic_diff", x);
src = performBinarySymbolicOperation("basic_diff", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSin)): {
src = performUnarySymbolicOperation("basic_sin", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicCos)): {
src = performUnarySymbolicOperation("basic_cos", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicLog)): {
src = performUnarySymbolicOperation("basic_log", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExp)): {
src = performUnarySymbolicOperation("basic_exp", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAbs)): {
src = performUnarySymbolicOperation("basic_abs", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand)): {
src = performUnarySymbolicOperation("basic_expand", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPi)): {
Expand Down Expand Up @@ -2794,22 +2835,6 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
src = target;
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand)): {
headers.insert("symengine/cwrapper.h");
LCOMPILERS_ASSERT(x.n_args == 1);
std::string target = symengine_queue.push();
std::string target_src = symengine_src;
this->visit_expr(*x.m_args[0]);
std::string arg1 = src;
std::string arg1_src = symengine_src;
if (ASR::is_a<ASR::Var_t>(*x.m_args[0])) {
symengine_queue.pop();
}
symengine_src = target_src + arg1_src;
symengine_src += indent + "basic_expand(" + target + ", " + arg1 + ");\n";
src = target;
return;
}
default : {
throw LCompilersException("IntrinsicFunction: `"
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)
Expand Down
118 changes: 80 additions & 38 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ enum class IntrinsicFunctions : int64_t {
SymbolicInteger,
SymbolicDiff,
SymbolicExpand,
SymbolicSin,
SymbolicCos,
SymbolicLog,
SymbolicExp,
SymbolicAbs,
Sum,
// ...
};
Expand Down Expand Up @@ -2169,45 +2174,52 @@ namespace SymbolicInteger {
}
} // namespace SymbolicInteger

namespace SymbolicExpand {

static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) {
const Location& loc = x.base.base.loc;
ASRUtils::require_impl(x.n_args == 1,
"SymbolicExpand must have exactly 1 input argument",
loc, diagnostics);

ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]);
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*input_type),
"SymbolicExpand expects an argument of type SymbolicExpression",
x.base.base.loc, diagnostics);
}

static inline ASR::expr_t *eval_SymbolicExpand(Allocator &/*al*/,
const Location &/*loc*/, Vec<ASR::expr_t*>& /*args*/) {
// TODO
return nullptr;
}

static inline ASR::asr_t* create_SymbolicExpand(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> err) {
if (args.size() != 1) {
err("Intrinsic expand function accepts exactly 1 argument", loc);
}

ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]);
if(!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) {
err("Argument of SymbolicExpand function must be of type SymbolicExpression",
args[0]->base.loc);
}

ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc));
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicExpand,
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand), 0, to_type);
}
#define create_symbolic_unary_macro(X) \
namespace X { \
\
static inline void verify_args(const ASR::IntrinsicFunction_t& x, \
diag::Diagnostics& diagnostics) { \
const Location& loc = x.base.base.loc; \
ASRUtils::require_impl(x.n_args == 1, \
#X " must have exactly 1 input argument", loc, diagnostics); \
\
ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*input_type), \
#X " expects an argument of type SymbolicExpression", loc, diagnostics); \
} \
\
static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \
Vec<ASR::expr_t*> &/*args*/) { \
/*TODO*/ \
return nullptr; \
} \
\
static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \
Vec<ASR::expr_t*>& args, \
const std::function<void (const std::string &, const Location &)> err) { \
if (args.size() != 1) { \
err("Intrinsic " #X " function accepts exactly 1 argument", loc); \
} \
\
ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); \
if (!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) { \
err("Argument of " #X " function must be of type SymbolicExpression", \
args[0]->base.loc); \
} \
\
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); \
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_##X, \
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::X), 0, to_type); \
} \
\
} // namespace X

} // namespace SymbolicExpand
create_symbolic_unary_macro(SymbolicSin)
create_symbolic_unary_macro(SymbolicCos)
create_symbolic_unary_macro(SymbolicLog)
create_symbolic_unary_macro(SymbolicExp)
create_symbolic_unary_macro(SymbolicAbs)
create_symbolic_unary_macro(SymbolicExpand)

namespace IntrinsicFunctionRegistry {

Expand Down Expand Up @@ -2275,6 +2287,16 @@ namespace IntrinsicFunctionRegistry {
{nullptr, &SymbolicDiff::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand),
{nullptr, &SymbolicExpand::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSin),
{nullptr, &SymbolicSin::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicCos),
{nullptr, &SymbolicCos::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicLog),
{nullptr, &SymbolicLog::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExp),
{nullptr, &SymbolicExp::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAbs),
{nullptr, &SymbolicAbs::verify_args}},
};

static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
Expand Down Expand Up @@ -2333,6 +2355,16 @@ namespace IntrinsicFunctionRegistry {
"SymbolicDiff"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand),
"SymbolicExpand"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSin),
"SymbolicSin"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicCos),
"SymbolicCos"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicLog),
"SymbolicLog"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExp),
"SymbolicExp"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAbs),
"SymbolicAbs"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Any),
"any"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Sum),
Expand Down Expand Up @@ -2372,6 +2404,11 @@ namespace IntrinsicFunctionRegistry {
{"SymbolicInteger", {&SymbolicInteger::create_SymbolicInteger, &SymbolicInteger::eval_SymbolicInteger}},
{"diff", {&SymbolicDiff::create_SymbolicDiff, &SymbolicDiff::eval_SymbolicDiff}},
{"expand", {&SymbolicExpand::create_SymbolicExpand, &SymbolicExpand::eval_SymbolicExpand}},
{"SymbolicSin", {&SymbolicSin::create_SymbolicSin, &SymbolicSin::eval_SymbolicSin}},
{"SymbolicCos", {&SymbolicCos::create_SymbolicCos, &SymbolicCos::eval_SymbolicCos}},
{"SymbolicLog", {&SymbolicLog::create_SymbolicLog, &SymbolicLog::eval_SymbolicLog}},
{"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}},
{"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}},
};

static inline bool is_intrinsic_function(const std::string& name) {
Expand Down Expand Up @@ -2488,6 +2525,11 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(SymbolicInteger)
INTRINSIC_NAME_CASE(SymbolicDiff)
INTRINSIC_NAME_CASE(SymbolicExpand)
INTRINSIC_NAME_CASE(SymbolicSin)
INTRINSIC_NAME_CASE(SymbolicCos)
INTRINSIC_NAME_CASE(SymbolicLog)
INTRINSIC_NAME_CASE(SymbolicExp)
INTRINSIC_NAME_CASE(SymbolicAbs)
INTRINSIC_NAME_CASE(Sum)
default : {
throw LCompilersException("pickle: intrinsic_id not implemented");
Expand Down
12 changes: 10 additions & 2 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7265,15 +7265,23 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
}

if (!s) {
std::string intrinsic_name = call_name;
std::set<std::string> not_cpython_builtin = {
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand",
"sum" // For sum called over lists
};
if (ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(call_name) &&
std::set<std::string> symbolic_functions = {
"sin", "cos", "log", "exp", "Abs"
};
if ((symbolic_functions.find(call_name) != symbolic_functions.end()) &&
imported_functions[call_name] == "sympy"){
intrinsic_name = "Symbolic" + std::string(1, std::toupper(call_name[0])) + call_name.substr(1);
}
if (ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(intrinsic_name) &&
(not_cpython_builtin.find(call_name) == not_cpython_builtin.end() ||
imported_functions.find(call_name) != imported_functions.end() )) {
ASRUtils::create_intrinsic_function create_func =
ASRUtils::IntrinsicFunctionRegistry::get_create_function(call_name);
ASRUtils::IntrinsicFunctionRegistry::get_create_function(intrinsic_name);
Vec<ASR::expr_t*> args_; args_.reserve(al, x.n_args);
visit_expr_list(x.m_args, x.n_args, args_);
if (ASRUtils::is_array(ASRUtils::expr_type(args_[0])) &&
Expand Down

0 comments on commit 39e1a26

Please sign in to comment.