/***************************************************************************
equation.cpp - Class for handling equations
-------------------
begin : Sun Oct 21 2001
copyright : (C) 2001 by Jan Rheinlaender
email : jrheinlaender@users.sourceforge.net
***************************************************************************/
/***************************************************************************
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
***************************************************************************/
#include "equation.h"
#include "unit.h"
#include "printing.h"
#include "expression.h"
#include "msgdriver.h" // *** added in 0.7
//#include "../config/config.h" // *** added in 1.1
#include <sstream>
#include "utils.h" // *** added in 0.8
// Register class equation with GiNaC as a subclass of relational. Default output goes to
// equation::do_print(), LaTeX output (print_latex) goes to equation::do_print_ltx().
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(equation, relational,
print_func<print_context>(&equation::do_print).
print_func<print_latex>(&equation::do_print_ltx));
// *** added in 0.9
// Counter for temporary labels generated by settemplabel(); reset to 0 by equation::clear().
int equation::nextlabel = 0;
// Static initialisation: constructing equation_initializer runs equation_init::equation_init(),
// which calls equation::init() the first time (guarded by the "called" flag).
bool equation_init::called = false;
equation_init equation_init::equation_initializer;
equation_init::equation_init() { // *** added in 0.6
  // Run equation::init() the first time an equation_init object is constructed;
  // any later construction is silently ignored.
  // (An error message for repeated calls was removed in 1.3.1, because msg_error
  // might not be initialized yet when this runs during static initialisation.)
  if (!called) {
    called = true;
    equation::init();
  }
} // equation_init::equation_init()
bool equation::initialized = false;
void equation::init() { // *** changed in 1.2
  // One-time global setup for equations; subsequent calls do nothing.
  if (!initialized) {
#ifndef _MSC_VER
    // *** added in 1.4.1 because MSVC does not initialize Digits before running this code
    // The program crashes in numeric_digits& _numeric_digits::operator=(long prec) at for (; it != end; ++it) {
    // The precision must not be too high, because of things like x^(0.99999999999999995)...
    Digits = 9;
#endif
    initialized = true;
  }
} // equation::init()
void equation::clear() { // *** added in 0.8
  // Restart temporary label numbering and restore the default numeric precision.
  nextlabel = 0; // next settemplabel() call produces label number 0 again
  Digits = 9;    // same value as set by equation::init()
} // equation::clear()
// constructors
equation::equation()
  : inherited(numeric(0), numeric(0), relational::equal),
    label(), al(none), type(empty), ltxrepr() { // *** added in 0.9
  // Default equation: 0 == 0, no label, no alignment, marked as empty and without a raw
  // LaTeX representation.
  // (tinfo_key assignments removed in 1.2.1 and 1.3.1)
} // equation::equation()
equation::equation(const expression &le, const expression &ri, const operators op,
                   const aligntype align, const eqtype t, const std::string &l)
  : inherited(le, ri, op), label(l) {
  // Build the relation "le op ri" with label l.
  // - type is t, unless both sides are empty, in which case the equation is empty.
  // - ltxrepr is assembled from the raw LaTeX of both sides if both have one.
  // - The alignment is taken from ampersands in the sides, falling back to align.
  // (tinfo_key assignments removed in 1.2.1 and 1.3.1)
  type = ((le.is_empty() && ri.is_empty()) ? empty : t);

  const bool have_raw = (le.getltxrepr() != "") && (ri.getltxrepr() != "") && (type != empty);
  if (have_raw) {
    // The ltxrepr of the two expressions already have ampersands!!
    std::string oper;
    switch(o) {
      case relational::equal:            oper = "=";    break;
      case relational::not_equal:        oper = "\\neq"; break;
      case relational::less:             oper = "<";    break;
      case relational::less_or_equal:    oper = "\\leq"; break;
      case relational::greater:          oper = ">";    break;
      case relational::greater_or_equal: oper = "\\geq"; break;
      default:                           oper = "[invalid operator]";
    }
    ltxrepr = le.getltxrepr() + oper + ri.getltxrepr();
  } else {
    ltxrepr = "";
  }

  const bool amp_left = le.getampersand();
  const bool amp_right = ri.getampersand();
  if (amp_left && amp_right) {
    al = both;
  } else if (amp_left) {
    al = onlyleft;
  } else if (amp_right) {
    al = onlyright;
  } else {
    al = align; // TODO: In the other cases align is ignored
  }
  MSG_INFO(3) << "Created equation '" << le << " = " << ri << "' with ltxrepr '" << ltxrepr << "'" << endline;
} // equation::equation()
equation::equation(const equation &eq)
  : inherited(eq.lh, eq.rh, eq.o),
    label(eq.label), al(eq.al), type(eq.type), ltxrepr(eq.ltxrepr) {
  // Copy constructor: the relational base is rebuilt from eq's two sides and operator,
  // and every equation-specific member (label, alignment, type, raw LaTeX) is copied.
  // (tinfo_key assignments removed in 1.2.1 and 1.3.1)
} // equation::equation(const equation&)
equation &equation::operator=(const equation &other) { // *** added in 0.9
  // Assign the two sides, the operator and all equation-specific members of other.
  lh = other.lh;
  rh = other.rh;
  o = other.o;
  al = other.al;
  label = other.label;
  type = other.type;
  ltxrepr = other.ltxrepr;
  // The contents were replaced in place, so status flags cached for the old contents
  // (hash value, evaluated/expanded state) no longer apply. Drop them so GiNaC recomputes
  // them on demand instead of using a stale hash.
  clearflag(status_flags::hash_calculated | status_flags::evaluated | status_flags::expanded);
  return *this;
} // equation::operator=()
// *** GiNaC standard methods
/* equation::equation(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst) {
// *** added in 0.9, removed in 1.3.1
n.find_string("label", label);
} // equation::equation(const archive_node&, lst&)
*/
void equation::read_archive(const archive_node &n, lst &sym_lst) { // *** added in 0.9, changed in 1.3.1 to read_archive() format
  // Restore the relational part (both sides and the operator) first, then the label.
  inherited::read_archive(n, sym_lst);
  n.find_string("label", label); // result ignored: a missing label leaves the member unchanged
  // NOTE(review): al, type and ltxrepr are neither archived nor restored here, so they keep
  // whatever value the object had before. The default constructor sets type to empty, which
  // do_print() renders as "[empty equation]". Confirm how the unarchiver constructs the object.
} // equation::unarchive()
GINAC_BIND_UNARCHIVER(equation); // *** added in 1.3.1
void equation::archive(archive_node &n) const { // *** added in 0.9
  // Store the relational part via the base class, plus the label as string property "label".
  // Counterpart of equation::read_archive().
  inherited::archive(n);
  n.add_string("label", label);
} // equation::archive()
int equation::compare_same_type(const basic &other) const { // *** added in 0.9
  // Ordering/equality is decided by the relational part alone (sides and operator);
  // label, alignment, type and ltxrepr are not taken into account.
  return inherited::compare_same_type(other);
} // equation::compare_same_type()
// methods
const bool equation::is_assignment() const {
  // True if the equation has the form  symbol == quantity, i.e. the lhs is a symbol,
  // the operator is "==" and the rhs contains only numerics and Units (see is_quantity()).
  // *** changed is_a<symbol> to is_a_symbol in 0.4, added check for equals in 0.9
  if (msg::info().checkprio(6)) {
    // Debug output: dump the expression tree of the rhs
    msg::info() << endline;
    std::ostringstream tree_dump;
    tree_dump << tree << rh << dflt;
    msg::info() << tree_dump.str();
    msg::info() << endline;
  }
  const bool symbol_equals = is_a<symbol>(lh) && (o == relational::equal);
  return symbol_equals ? is_quantity(rh) : false;
}
const bool equation::is_temporary() const { // *** added in 0.7
  // An equation is temporary if it has no label, or a label enclosed in braces
  // (the form produced by settemplabel()).
  if (label.empty()) return true;
  return (label[0] == '{') && (label[label.size() - 1] == '}');
} // equation::is_temporary()
const equation &equation::settemplabel() {
  // Give this equation the next temporary label: "{" + counter right-aligned in 7 columns + "}",
  // e.g. "{      0}". is_temporary() recognises such labels by the enclosing braces.
  // *** added in 0.7, changed return type in 0.9, changed to use sprintf() in 1.0, changed to char[10] in 1.3 because of buffer overflow
  // The former sprintf("%c%7u%c") into char[10] overflowed as soon as the counter needed 8 digits
  // (1 + 8 + 1 characters plus the terminating NUL). Building the string avoids any fixed buffer.
  std::ostringstream counter;
  counter << static_cast<unsigned>(nextlabel++); // same unsigned rendering as "%u" did
  std::string digits = counter.str();
  if (digits.size() < 7) digits.insert(0, 7 - digits.size(), ' '); // emulate the "%7u" field width
  label = "{" + digits + "}";
  return *this;
} // settemplabel()
const equation &equation::addalign(const aligntype &align) { // *** added in 0.9
  // Combine an additional alignment request with the current alignment:
  //  none      -> take align
  //  onlyleft  -> stays onlyleft if align == left, otherwise both
  //  onlyright -> stays onlyright if align == right, otherwise both
  //  anything else -> take align
  // NOTE(review): the comparisons use `left`/`right`, not `onlyleft`/`onlyright`; confirm
  // these are the intended aligntype values.
  if (al == none) {
    al = align;
  } else if (al == onlyleft) {
    al = (align == left) ? onlyleft : both;
  } else if (al == onlyright) {
    al = (align == right) ? onlyright : both;
  } else {
    al = align;
  }
  return *this;
} // equation::addalign()
void equation::print_oper(const print_context &c) const {
  // Print the relational operator in plain (GiNaC/C++) notation.
  const char *text = "[invalid operator]";
  switch(o) {
    case relational::equal:            text = "=="; break;
    case relational::not_equal:        text = "!="; break;
    case relational::less:             text = "<";  break;
    case relational::less_or_equal:    text = "<="; break;
    case relational::greater:          text = ">";  break;
    case relational::greater_or_equal: text = ">="; break;
    default:                           break; // keep the invalid-operator text
  }
  c.s << text;
} // equation::print_oper()
void equation::print_oper_ltx(const print_context &c) const {
  // Print the relational operator in LaTeX notation. The operator is surrounded by the
  // alignment ampersands selected by the current eqalign option
  // (none: "op", onlyleft: "&op", onlyright: "op&", both: "&op&").
  std::string oper;
  switch(o) {
    case relational::equal:            oper = "=";    break;
    case relational::not_equal:        oper = "\\neq"; break;
    case relational::less:             oper = "<";    break;
    case relational::less_or_equal:    oper = "\\leq"; break;
    case relational::greater:          oper = ">";    break;
    case relational::greater_or_equal: oper = "\\geq"; break;
    // Same fallback as the constructor uses for ltxrepr. Previously this text was written
    // directly to the stream and ended up outside the alignment ampersands below.
    default:                           oper = "[invalid operator]";
  }
  switch (optstack::options->get(o_eqalign).align) {
    case none:      { c.s << oper; break; }
    case onlyleft:  { c.s << "&" << oper; break; }
    case onlyright: { c.s << oper << "&"; break; }
    case both:      { c.s << "&" << oper << "&"; break; }
    default:        { c.s << "[invalid alignment]"; }
  }
} // equation::print_oper_ltx()
void equation::do_print(const print_context &c, unsigned level) const { // *** changed in 0.9
  // Plain-text output: "lhs op rhs", or a marker for an empty equation.
  if (type == empty) { // *** added in 0.6
    c.s << "[empty equation]";
    return;
  }
  // Both sides are printed through expression::print(); *** changed in 1.3.1, just doing
  // c.s << lh crashes openoffice
  c.s << expression(lh).print();
  print_oper(c);
  c.s << expression(rh).print();
} // equation::do_print()
const unsigned countampersand(const std::string &s) { // *** added in 0.9
  // Return the number of '&' characters in s.
  // Positions are kept as std::string::size_type. The former code stored them in an unsigned
  // int, which truncates std::string::npos and only terminated because strings are short.
  unsigned n = 0;
  for (std::string::size_type pos = s.find('&'); pos != std::string::npos; pos = s.find('&', pos + 1))
    ++n;
  return n;
} // countampersand()
const std::string removeampersand(const std::string &s) {
  // Return a copy of s with every '&' removed.
  // The former find/erase loop decremented an unsigned position after erasing at index 0.
  // The position wrapped around and the search stopped early, e.g. "&x=&y" gave "x=&y".
  std::string result;
  result.reserve(s.size());
  for (std::string::const_iterator it = s.begin(); it != s.end(); ++it)
    if (*it != '&') result += *it;
  return result;
} // removeampersand()
void equation::do_print_ltx(const print_context &c, unsigned level) const { // *** added in 0.9
MSG_INFO(2) << "Printing equation (do_print_ltx) " << lh << " = " << rh << endline;
if (type == empty) {
c.s << "[empty equation]";
} else {
if (optstack::options->get(o_eqginac).boolean) { // *** added in 1.1 for test purposes
c.s << latex << Unit::subst_units(lh, *optstack::options->get(o_units).vec); // *** added optstack::options->get in 1.4.1
print_oper_ltx(c);
c.s << latex << Unit::subst_units(rh, *optstack::options->get(o_units).vec);
} else if (optstack::options->get(o_eqraw).boolean && (ltxrepr != "")) {
// Note: Automatic alignment occurs even in eqraw mode-use eqalign=none etc. to selectively turn it off
switch (optstack::options->get(o_eqalign).align) {
case none: { c.s << removeampersand(ltxrepr); break; }
case onlyleft: { if (countampersand(ltxrepr) != 1) {
std::string lr = removeampersand(ltxrepr);
c.s << lr.insert(lr.find("="), "&");
} else
c.s << ltxrepr;
break;
}
case onlyright: { if (countampersand(ltxrepr) != 1) {
std::string lr = removeampersand(ltxrepr);
c.s << lr.insert(lr.find("=") - 1, "&");
} else
c.s << ltxrepr;
break;
}
case both: { if (countampersand(ltxrepr) != 2) {
std::string lr = removeampersand(ltxrepr);
lr.insert(lr.find("="), "&"); // *** corrected this in 1.2
c.s << lr.insert(lr.find("=") + 1, "&");
} else
c.s << ltxrepr;
break;
}
default: { c.s << "[invalid alignment]"; }
}
} else {
if (optstack::options->get(o_eqalign).align == none) {
print_ltx(Unit::subst_units(lh, *optstack::options->get(o_units).vec), c.s, true); // *** changed in 1.0, added optstack::options->get in 1.4.1
} else if (!optstack::options->get(o_eqchain).boolean) {
print_ltx(Unit::subst_units(lh, *optstack::options->get(o_units).vec), c.s, true); // *** changed in 1.0
} // else omit the lhs
print_oper_ltx(c);
int pos = optstack::options->get(o_eqsplit).integer;
if (pos > 0) {
optstack::options->save(o_eqsplit);
option_value o;
o.expr = new ex((int)(pos - (lh.nops() == 0?1:lh.nops())));// TODO: status_flags::dyn_allocated?
optstack::options->set(o_eqsplit, o); // Adjust the split position *** added in 1.2
MSG_INFO(2) << "Adjusted split position for rhs to "
<< optstack::options->get(o_eqsplit).integer << endline;
print_ltx(Unit::subst_units(rh, *optstack::options->get(o_units).vec), c.s, true); // *** changed in 1.0, added optstack::options->get in 1.4.1
optstack::options->restore(o_eqsplit);
} else {
print_ltx(Unit::subst_units(rh, *optstack::options->get(o_units).vec), c.s, true);
}
}
}
} // equation::do_print_ltx()
const equation equation::subs(const ex &e, unsigned options) const
  throw (std::invalid_argument) { // *** added in 0.9
  // Apply the substitution e to both sides (via expression::subs) and return the result as a
  // new derived equation with the same operator and alignment.
  const expression new_lhs = expression(lh).subs(e, options);
  const expression new_rhs = expression(rh).subs(e, options);
  return equation(new_lhs, new_rhs, o, al, derived);
} // equation::subs(const ex&, unsigned)
ex equation::subs(const exmap &m, unsigned options) const
  throw (std::invalid_argument) { // *** added in 1.0
  // Apply the substitution map m to both sides (via expression::subs); the result is a new
  // derived equation with the same operator and alignment, returned as ex.
  const expression new_lhs = expression(lh).subs(m, options);
  const expression new_rhs = expression(rh).subs(m, options);
  return equation(new_lhs, new_rhs, o, al, derived);
} // equation::subs(const exmap &, unsigned)
const equation equation::csubs(const exmap &m, unsigned options) const { // *** added in 0.9
  // Apply expression::csubs with map m to both sides; returns a new derived equation with
  // the same operator and alignment.
  const expression new_lhs = expression(lh).csubs(m, options);
  const expression new_rhs = expression(rh).csubs(m, options);
  return equation(new_lhs, new_rhs, o, al, derived);
} // equation::csubs()
const equation equation::csubs(const ex &e, unsigned options) const
  throw(std::invalid_argument) {
  // Apply expression::csubs with e to both sides; returns a new derived equation with the
  // same operator and alignment.
  const expression new_lhs = expression(lh).csubs(e, options);
  const expression new_rhs = expression(rh).csubs(e, options);
  return equation(new_lhs, new_rhs, o, al, derived);
} // equation::csubs(const lst&, unsigned)
ex equation::expand(unsigned options) const {
  // Expand both sides; the result is a new derived equation with the same operator and alignment.
  // *** added in 0.7, changed return type to ex in 0.9
  // *** added cast to expression in 1.0 to benefit from expand_real_powers
  MSG_INFO(1) << "Expanding " << *this << endline;
  const expression expanded_lhs = expression(lh).expand(options);
  const expression expanded_rhs = expression(rh).expand(options);
  return equation(expanded_lhs, expanded_rhs, o, al, derived);
} // equation::expand()
const equation equation::evalf() const { // *** added in 0.7
  // Numerically evaluate both sides; returns a new derived equation (same operator/alignment).
  MSG_INFO(3) << "Evaluating " << *this << endline;
  const expression num_lhs = lh.evalf();
  const expression num_rhs = rh.evalf();
  return equation(num_lhs, num_rhs, o, al, derived);
} // equation::evalf()
ex equation::evalm() const { // *** added in 1.0
  // Evaluate matrix expressions on both sides and return a new derived equation with the same
  // operator and alignment, like evalf()/evalu()/normal().
  // Previously *this was copied and only lh/rh replaced. The copy kept the old label (risking
  // duplicate labels, cf. simplify()/solve()) and the old raw LaTeX ltxrepr, so eqraw output
  // showed the unevaluated equation.
  MSG_INFO(3) << "Evaluating matrices in " << *this << endline;
  return equation(lh.evalm(), rh.evalm(), o, al, derived);
} // equation::evalm()
const equation equation::evalu() const { // *** added in 0.9
  // Apply expression::evalu() ("unsafe" simplifications, cf. simplify()) to both sides;
  // returns a new derived equation with the same operator and alignment.
  const expression u_lhs = expression(lh).evalu();
  const expression u_rhs = expression(rh).evalu();
  return equation(u_lhs, u_rhs, o, al, derived);
} // equation::evalu()
const equation &equation::eqadd(const expression &add) throw (std::invalid_argument) {
  // *** added in 0.9
  // Add add to both sides in place and mark the equation as derived. If add is itself an
  // equation, its lhs is added to our lhs and its rhs to our rhs. Both equations must then
  // have the same operator (std::invalid_argument otherwise); differing alignment only warns.
  // The operator check now happens before lh/rh are modified, so *this is left unchanged when
  // the exception is thrown. Previously both sides had already been changed.
  if (is_a<equation>(add)) {
    const equation &other = ex_to<equation>(add);
    if (o != other.o)
      throw std::invalid_argument("The two equations must have the same operator sign");
    if (al != other.al)
      msg::warn(1) << "Warning: Alignment types of the equations mismatch" << endline;
    lh += other.lh;
    rh += other.rh;
  } else {
    lh += add;
    rh += add;
  }
  type = derived;
  return *this;
} // equation::eqadd()
const equation &equation::eqmul(const expression &mul) throw (std::invalid_argument) {
  // *** added in 0.9
  // Multiply both sides by mul in place and mark the equation as derived.
  // Multiplying by another equation is rejected with std::invalid_argument.
  if (is_a<equation>(mul))
    throw std::invalid_argument("Cannot multiply two equations");
  // For inequalities the sign of mul is not checked, so only a warning is given
  const bool is_inequality = (o != relational::equal);
  if (is_inequality)
    msg::warn(0) << "Warning: Operator sign might have changed in multiplication" << endline;
  lh *= mul;
  rh *= mul;
  type = derived;
  return *this;
} // equation::eqmul()
const equation equation::normal() const { // *** added in 0.9
  // Normalize both sides (GiNaC normal form); returns a new derived equation with the same
  // operator and alignment.
  const expression norm_lhs = lh.normal();
  const expression norm_rhs = rh.normal();
  return equation(norm_lhs, norm_rhs, o, al, derived);
} // equation::normal()
const expression equation::apply_func(const expression &e) const throw (std::invalid_argument) {
  // *** added in 0.9
  // Apply the function named by the func object e to both sides, giving f(lhs) op f(rhs).
  // Throws std::invalid_argument if e is not a func.
  MSG_INFO(1) << "Applying function " << e << " to " << *this << endline;
  if (!is_a<func>(e))
    throw std::invalid_argument("Argument must be a function name");
  const func &f = ex_to<func>(e);
  const expression f_lhs = func(f.get_name(), lh);
  const expression f_rhs = func(f.get_name(), rh);
  return equation(f_lhs, f_rhs, o, al, derived);
} // equation::apply_func()
const expression equation::apply_func(const std::string &fname) const throw (std::invalid_argument) {
  // *** added in 1.4.1
  // Apply the function called fname to both sides, giving fname(lhs) op fname(rhs).
  // Throws std::invalid_argument if fname is not a known function (func::is_a_func).
  MSG_INFO(1) << "Applying function " << fname << " to " << *this << endline;
  if (func::is_a_func(fname)) {
    const expression f_lhs = func(fname, lh);
    const expression f_rhs = func(fname, rh);
    return equation(f_lhs, f_rhs, o, al, derived);
  }
  throw std::invalid_argument("Argument must be a function name");
} // equation::apply_func()
const expression equation::apply_power(const expression &e) const throw (std::invalid_argument) {
  // *** added in 1.4.1
  // Raise both sides to the power e; returns a new derived equation with the same operator
  // and alignment. The operator itself is not adjusted.
  MSG_INFO(1) << "Raising " << *this << " to the power of " << e << endline;
  const expression pow_lhs = power(lh, e);
  const expression pow_rhs = power(rh, e);
  return equation(pow_lhs, pow_rhs, o, al, derived);
} // equation::apply_power()
const expression equation::diff(const ex &var, unsigned nth) const
  throw (std::logic_error) { // *** added in 1.0
  // Differentiate both sides nth times, either with respect to a symbol (delegates to
  // diff(const symbol&, unsigned)) or with respect to a pure function (via diff_to_func()).
  // Any other var raises std::logic_error.
  if (is_a<symbol>(var))
    return diff(ex_to<symbol>(var), nth);
  if (!(is_a<func>(var) && ex_to<func>(var).is_pure()))
    throw std::logic_error("Can only differentiate with respect to a variable or a pure function!");
  const func &f = ex_to<func>(var);
  return equation(diff_to_func(lh, f, nth), diff_to_func(rh, f, nth), o, al, derived);
} // equation::diff()
const expression equation::diff(const symbol &var, unsigned nth) const {
  // *** added in 0.6, changed in 0.9
  // Differentiate both sides nth times with respect to var. GiNaC functions introduced by the
  // differentiation are mapped back to func objects (replace_function_by_func, added in 1.3.0).
  // *** changed this->equals() to equal_sign in 0.7
  MSG_INFO(1) << "Calculating " << nth << "nth derivative of " << *this << " to " << var << endline;
  func::replace_function_by_func replace_functions;
  const expression d_lhs = replace_functions(lh.diff(var, nth));
  const expression d_rhs = replace_functions(rh.diff(var, nth)); // *** added nth in 1.3.0, this had been forgotten!
  return equation(d_lhs, d_rhs, o, al, derived);
} // equation::diff()
const equation equation::reverse() const {
  // Swap the two sides (the operator is kept as is); returns a new derived equation.
  // *** changed in_eqnarray to equals in 0.5, added derived in 0.6, changed equals() to getequals() in 0.8
  MSG_INFO(1) << "Reversed equation " << *this << endline;
  const expression new_lhs = rh;
  const expression new_rhs = lh;
  return equation(new_lhs, new_rhs, o, al, derived);
} // equation::reverse()
const equation equation::simplify(const std::vector<std::string> &s) const {
  // Apply the simplification steps named in s, in order, to an unlabeled copy of this
  // equation and return the result. Unknown step names only produce a warning.
  // *** changed to handle a vector of simplifications in 0.9
  // *** removed bug in 1.2: result.lhs() was result.lh etc. for diff and sum
  equation result(*this);
  result.setlabel(""); // Clear label to avoid duplicate label errors *** added in 1.4.2
  for (std::vector<std::string>::const_iterator step = s.begin(); step != s.end(); ++step) {
    const std::string &how = *step;
    if (how == "expand") { // Full expansion, including function arguments
      MSG_INFO(1) << "Full expanding " << result << endline;
      result = ex_to<equation>(result.expand(expand_options::expand_function_args));
    } else if (how == "expandf") { // Expansion that leaves function arguments alone
      MSG_INFO(1) << "Expanding " << result << endline;
      result = ex_to<equation>(result.expand());
    } else if (how == "eval") { // Numerical evaluation *** added in 0.7
      MSG_INFO(1) << "Evaluating " << result << endline;
      result = result.evalf();
    } else if (how == "normal") { // Normalization *** added in 0.7
      MSG_INFO(1) << "Normalizing " << result << endline;
      result = result.normal();
    } else if (how == "collect-common") { // Collecting common factors *** added in 0.8
      MSG_INFO(1) << "Collecting common factors in " << result << endline;
      result = collect_common_factors(result);
    } else if (how == "unsafe") { // Unsafe simplifications *** added in 0.9
      MSG_INFO(1) << "Doing unsafe simplifications in " << result << endline;
      result = result.evalu();
    } else if (how == "diff") { // Evaluate \diff function objects *** added in 1.0
      func::expand_partialdiff expand_diff;
      result = equation(expand_diff(result.lhs()), expand_diff(result.rhs()), o, al, derived);
    } else if (how == "sum") { // Evaluate \sum function objects *** added in 1.0
      func::expand_sum expand_s;
      result = equation(expand_s(result.lhs()), expand_s(result.rhs()), o, al, derived);
    } else if (how == "gather-sqrt") { // *** added in 1.2
      MSG_INFO(1) << "Gathering square roots in " << result << endline;
      gather_sqrt gather_sqrts;
      result = equation(gather_sqrts(result.lhs()), gather_sqrts(result.rhs()), o, al, derived);
    } else if (how == "integrate") { // *** added in 1.3.1
      MSG_INFO(1) << "Symbolically integrating " << result << endline;
      result = equation(expression(result.lhs()).eval_integral(), expression(result.rhs()).eval_integral(), o, al, derived);
      MSG_INFO(2) << "Result: " << result << endline;
    } else { // Nothing of the above fitted
      msg::warn(0) << "Warning: Unknown simplification of type " << how << endline;
    }
  }
  return result;
} // equation::simplify()
expression qsolve(const std::vector<expression> &coeff, const int num) {
  // Helper function for equation::solve(): return solution number num of the quadratic
  //   coeff[0] + coeff[1]*x + coeff[2]*x^2 = 0,
  // using x = -h +/- sqrt(h^2 - coeff[0]/coeff[2]) with h = coeff[1]/(2*coeff[2]).
  // num == 1 selects the "+" root, any other value the "-" root.
  // Assumes coeff has at least three entries and coeff[2] is non-zero (solve() only calls this
  // for degree-2 problems).
  expression help = coeff[1]/expression((numeric(2) * coeff[2]));
  expression det = pow(help, numeric(2)) - coeff[0]/coeff[2];
  if (det.is_zero()) {
    // Double root: only a request beyond the first solution deserves a warning.
    // (The former check "num == 1" warned for solution #1 and stayed silent for #2.)
    if (num > 1) msg::warn(1) << "Warning: This equation has only one solution" << endline;
    return (-help);
  } else {
    if (num > 2) msg::warn(1) << "Warning: This equation has only two solutions" << endline;
    return ((num == 1) ? -help + pow(det, expression(numeric(1,2)))
                       : -help - pow(det, expression(numeric(1,2))));
  }
} // qsolve()
expression csolve(const std::vector<expression> &coeff, const int num) {
// Helper function for equation::solve() *** added in 1.0
// Return solution number num (0 or 1: first, 2: second, 3: third) of the cubic
//   coeff[0] + coeff[1]*x + coeff[2]*x^2 + coeff[3]*x^3 = 0
// (a = coeff[3], b = coeff[2], c = coeff[1], d = coeff[0]). Throws std::invalid_argument for num > 3.
// Use the formulae of Cardan to solve the cubic problem. The solution is done for
// y^3 + 3 p y + 2 q = 0, where y = x + b/(3a)
// Note that sqrt(3) can not be used because it evaluates to a real number immediately
// 2q = 2b^3/(27a^3) - bc/(3a^2) + d/a
expression q = numeric(1,2) * (numeric(2) * pow(coeff[2], numeric(3)) /
(numeric(27) * pow(coeff[3], numeric(3))) -
coeff[2] * coeff[1] / (numeric(3) * pow(coeff[3], numeric(2))) +
coeff[0]/coeff[3]);
// p = (3ac - b^2) / (9a^2)
expression p = numeric(1,3) * ((numeric(3) * expression(coeff[3] * coeff[1]) -
pow(coeff[2], numeric(2))) /
(numeric(3) * pow(coeff[3], numeric(2))));
//0) << "Coefficient p : " << p << ", q: " << q << endline;
// u, v = cbrt(-q +/- sqrt(q^2 + p^3)).
// NOTE(review): pow(..., 1/3) picks a principal branch; the classical condition u*v = -p is
// not enforced here.
expression u = pow(-q + pow(pow(q, numeric(2)) + pow(p, numeric(3)), numeric(1,2)), numeric(1,3));
expression v = pow(-q - pow(pow(q, numeric(2)) + pow(p, numeric(3)), numeric(1,2)), numeric(1,3));
// The two complex cube roots of unity: -1/2 +/- i*sqrt(3)/2
expression eps1 = numeric(-1,2) + I * numeric(1,2) * pow(3, numeric(1,2));
expression eps2 = numeric(-1,2) - I * numeric(1,2) * pow(3, numeric(1,2));
//0) << "eps1 = " << eps1 << "; eps2 = " << eps2 << endline;
// The three roots of the reduced cubic
expression y1 = u + v;
expression y2 = eps1 * u + eps2 * v;
expression y3 = eps2 * u + eps1 * v;
// Undo the substitution: x = y - b/(3a)
switch (num) {
case 0: // This returns the first solution, also
case 1: return (y1 - expression(coeff[2] / expression(numeric(3) * coeff[3])));
case 2: return (y2 - expression(coeff[2] / expression(numeric(3) * coeff[3])));
case 3: return (y3 - expression(coeff[2] / expression(numeric(3) * coeff[3])));
default:
throw std::invalid_argument("Not more than three solutions exist");
}
} // csolve()
// Solve the polynomial problem lh - rh = 0 for the symbol e and return solution number n
// as the derived equation  e <op> solution. The problem is expanded and normalized first,
// and only its numerator is used.
// Supported: degree 0 (returns an unlabeled copy of this equation), 1, 2 (qsolve), 3 (csolve,
// or a direct cube root when only e^3 and a constant occur), and degree 4 only in biquadratic
// form (no e^1 and e^3 terms).
// Throws std::invalid_argument if e is not a symbol, n is not a positive integer, or the
// degree/form is not supported.
const equation equation::solve(const expression &e, const expression &n) const
throw(std::invalid_argument) {
// *** added in 0.8
MSG_INFO(1) << "Solving " << *this << " for " << e << ", solution #" << n << " requested " <<endline;
if (!is_a<symbol>(e)) throw std::invalid_argument("Solving only works for symbols!");
if (!is_a<numeric>(n)) throw std::invalid_argument("The solution number must be numerical!");
if (!ex_to<numeric>(n).is_pos_integer())
throw std::invalid_argument("The solution number must be a positive integer!");
unsigned num = ex_to<numeric>(n).to_int();
// Move everything to one side; only the numerator determines the roots
expression problem = (lh - rh).expand().normal(); // *** added normal() and numer() in 1.1
if (!denom(problem).is_equal(1)) {
MSG_WARN(2) << "Warning: The solution may not be defined for some values because the denominator is not 1" << endline;
problem = numer(problem);
}
unsigned degree = problem.degree(e);
// A request beyond the number of solutions is clamped to the last one
if (num > degree) {
msg::warn(0) << "Warning: This equation can only have " << degree
<< " solutions. Returning the last one." << endline;
num = degree;
}
std::vector<expression> coeff;
for (unsigned i = 0; i <= degree; i++) coeff.push_back(problem.coeff(e, i));
// The problem now has the form coeff[0] + coeff[1] * e + ... + coeff[degree] * e^degree
// TODO: What about negative degrees?
if (msg::info().checkprio(2)) {
msg::info() << "Degree of the problem is " << degree << endline << "Coefficients: ";
for (std::vector<expression>::const_iterator i = coeff.begin(); i < coeff.end(); i++)
msg::info() << *i << " ";
msg::info() << endline;
}
switch (degree) {
case 0: {
// e does not occur: nothing to solve, hand back the equation itself
equation result = *this;
result.setlabel(""); // Clear label to avoid duplicate equations! *** added in 1.4.2
return result;
}
case 1: return equation (e, -1 * (coeff[0]/coeff[1]), o, al, derived);
case 2: return equation (e, qsolve(coeff, num), o, al, derived);
case 3: {
if (coeff[1].is_zero() && coeff[2].is_zero()) { // added in 1.4.1, since csolve produces division by zero
// coeff[0] + coeff[3] * e^3 = 0  =>  e = (-coeff[0]/coeff[3])^(1/3)
if (num > 1)
msg::warn(0) << "Warning: This equation can only have one solution. Returning it." << endline;
return equation(e, pow(-1 * (coeff[0]/coeff[3]), numeric(1,3)), o, al, derived);
} else {
return equation (e, csolve(coeff, num), o, al, derived); // *** added in 1.0
}
}
case 4: {
if (coeff[1].is_zero() && coeff[3].is_zero()) {
// The problem has the form coeff[0] + coeff[2] * e^2 + coeff[4] * e^4
// Solve the quadratic in e^2, then take +/- square roots; num is 1..4 here after clamping
std::vector<expression> newcoeff;
newcoeff.push_back(coeff[0]);
newcoeff.push_back(coeff[2]);
newcoeff.push_back(coeff[4]);
switch (num) {
case 1: return equation(e, sqrt(qsolve(newcoeff, 1)), o, al, derived);
case 2: return equation(e, -sqrt(qsolve(newcoeff, 1)), o, al, derived);
case 3: return equation(e, sqrt(qsolve(newcoeff, 2)), o, al, derived);
case 4: return equation(e, -sqrt(qsolve(newcoeff, 2)), o, al, derived);
}
}
// Not biquadratic: deliberately falls through to the default case, which throws
}
default: {
throw std::invalid_argument("Solving for higher degrees than 3 is not implemented yet.");
// TODO: Check whether degree() - ldegree() > 1 instead
}
}
} // equation::solve()
const equation collect_common_factors(const equation &eq) {
  // Apply GiNaC::collect_common_factors() to each side of eq separately; returns a new
  // derived equation with eq's operator and alignment.
  const expression lhs_factored = GiNaC::collect_common_factors(eq.lhs());
  const expression rhs_factored = GiNaC::collect_common_factors(eq.rhs());
  return equation(lhs_factored, rhs_factored, eq.getop(), eq.getalign(), derived);
} // collect_common_factors(const equation &)
const ex operator+(const expression &e, const expression &add) throw(std::invalid_argument) {
  // *** changed in 0.9
  // Addition that understands equations. If either operand is an equation (the left one takes
  // precedence), the other operand is added to both of its sides via equation::eqadd().
  // Otherwise this is ordinary GiNaC addition.
  MSG_INFO(3) << "Adding " << add << " to " << e << endline;
  if (!is_a<equation>(e) && !is_a<equation>(add))
    return operator+(ex(e), ex(add));
  const bool left_is_eq = is_a<equation>(e);
  equation result = ex_to<equation>(left_is_eq ? e : add);
  return result.eqadd(left_is_eq ? add : e);
} // operator+()
const ex operator-(const expression &e, const expression &sub) throw(std::invalid_argument) {
  // *** changed in 0.9
  // Subtraction that understands equations:
  //  equation - x : add -x to both sides,
  //  x - equation : negate both sides of the equation, then add x,
  //  otherwise    : ordinary GiNaC subtraction.
  MSG_INFO(3) << "Subtracting " << sub << " from " << e << endline;
  if (is_a<equation>(e)) {
    equation minuend = ex_to<equation>(e);
    return minuend.eqadd(sub * numeric(-1));
  }
  if (is_a<equation>(sub)) {
    equation negated = ex_to<equation>(sub * numeric(-1));
    return negated.eqadd(e);
  }
  return operator-(ex(e), ex(sub));
}
const ex operator*(const expression &e, const expression &mul) throw(std::invalid_argument) {
  // *** changed in 0.9
  // Multiplication that understands equations. If either operand is an equation (the left one
  // takes precedence), both of its sides are multiplied by the other operand via
  // equation::eqmul(). Otherwise this is ordinary GiNaC multiplication.
  // NOTE(review): a zero right operand is rejected for all operands (also plain expressions),
  // while a zero left operand is not checked. Confirm this asymmetry is intended.
  // (An earlier idea of multiplying two equalities side by side is not implemented.)
  MSG_INFO(3) << "Multiplying " << e << " with " << mul << endline;
  if (mul == 0) throw std::invalid_argument("Multiplication with zero");
  if (!is_a<equation>(e) && !is_a<equation>(mul))
    return operator*(ex(e), ex(mul));
  const bool left_is_eq = is_a<equation>(e);
  equation result = ex_to<equation>(left_is_eq ? e : mul);
  return result.eqmul(left_is_eq ? mul : e);
}
const ex operator/(const expression &e, const expression &divisor) throw(std::invalid_argument) {
  // *** changed in 0.9
  // Division that understands equations:
  //  equation / equation : side-by-side division, only allowed for two equalities,
  //  equation / x        : multiply both sides by x^-1,
  //  x / equation        : invert both sides of the equation, then multiply by x,
  //  otherwise           : ordinary GiNaC division.
  // Division by zero raises std::invalid_argument.
  MSG_INFO(3) << "Dividing " << e << " by " << divisor << endline;
  if (divisor == 0) throw std::invalid_argument("Division by zero");
  const bool dividend_is_eq = is_a<equation>(e);
  const bool divisor_is_eq = is_a<equation>(divisor);
  if (dividend_is_eq && divisor_is_eq) {
    const equation &top = ex_to<equation>(e);
    const equation &bottom = ex_to<equation>(divisor);
    if ((top.getop() != relational::equal) || (bottom.getop() != relational::equal))
      throw std::invalid_argument("Only equalities can be divided by an equality");
    return equation(top.lhs() / bottom.lhs(), top.rhs() / bottom.rhs(), top.getop(), top.getalign(), derived);
  }
  if (dividend_is_eq) {
    equation result = ex_to<equation>(e);
    return result.eqmul(pow(divisor, numeric(-1)));
  }
  if (divisor_is_eq) {
    equation result = ex_to<equation>(pow(divisor, numeric(-1)));
    return result.eqmul(e);
  }
  return operator/(ex(e), ex(divisor));
}
const ex pow(const expression &e, const expression &exponent) throw (std::invalid_argument) {
  // *** added in 0.9, takes care of equations
  // Exponentiation that understands equations: for an equation, both sides are raised to
  // exponent (a derived equation with the same operator and alignment is returned). An
  // equation as exponent raises std::invalid_argument. Otherwise this is ordinary GiNaC pow.
  MSG_INFO(3) << "Calculating power of " << e << " to " << exponent << endline;
  if (is_a<equation>(exponent))
    throw std::invalid_argument("Cannot calculate power to exponent that is an equation");
  if (!is_a<equation>(e))
    return pow(ex(e), ex(exponent));
  const equation &eq = ex_to<equation>(e);
  if (eq.getop() != relational::equal)
    msg::warn(0) << "Warning: Operator sign might have changed in exponentiation" << endline;
  return equation(pow(eq.lhs(), exponent), pow(eq.rhs(), exponent), eq.getop(), eq.getalign(), derived);
}
const ex operator-(const expression &e) { // *** added in 0.9
  // Unary minus, expressed as multiplication by -1 (for an equation, both sides are negated).
  const ex negated = e * numeric(-1);
  return negated;
}
// new input/output operators
/*message &operator<<(message &ms, const equation &eq) { // *** changed to message stream in 0.7
std::ostringstream os;
eq.print(print_dflt(os)); // *** changed to use print_dflt in 0.9
ms << os.str();
return (ms);
}
*/