#!/usr/bin/env python
#
# LiberMate
#
# Copyright (C) 2009  Eric C. Schug
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

__author__ = "Eric C. Schug (schugschug@gmail.com)"
__version__ = "0.1"
__copyright__ = "Copyright (c) 2009 Eric C. Schug"
__license__ = "GNU General Public License"
__revision__ = "$Id$"

# Standard Python
import sys
import re
from copy import copy
import glob
import os
import stat
import pprint


# Optional YAML output support through the PySyck package.  has_syck records
# whether the import succeeded; when it fails the module still works, only
# without DumpNoTuple.
has_syck=False
try:
    import syck
    has_syck=True
    class DumpNoTuple(syck.Dumper):
        'syck Dumper that represents tuples as plain sequences (tag ...:seq)'
        def represent_tuple(self, object):
            # NOTE(review): the name _syck is not imported anywhere in this
            # file, so this line would raise NameError if it were ever
            # reached; presumably PySyck's internal _syck module (or
            # syck.Seq) was intended -- confirm against the PySyck API.
            return _syck.Seq(list(object), tag="tag:python.yaml.org,2002:seq")

except ImportError:
    # Deliberate best effort: syck is optional.
    pass

### The grammar declaration "class MatlabLexer extends Lexer" generates the
### python module "MatlabLexer" with class "Lexer".
import MatlabLexer
import MatlabParser
import Mat2Py

import antlr
import CommandLine





def ASTtoTree(ast):
    '''Convert an ANTLR AST sibling chain into nested Python lists.

    Walks ast and all of its following siblings.  Each node becomes
    ((token_name, text), children), where children is the recursively
    converted chain of the node's children, or None when it has none.
    A falsy ast yields an empty list.
    '''
    tree=[]
    node=ast
    while node:
        # Convert the children first; an empty result is stored as None.
        children=ASTtoTree(node.getFirstChild()) or None
        label=(MatlabParser._tokenNames[node.getType()], node.getText())
        tree.append((label, children))
        node=node.getNextSibling()
    return tree

        

# Binding strength of the Python operators emitted by the translator; the
# Mat2PyTrans operator helpers compare these numbers to decide where to add
# parentheses.  Larger numbers bind more tightly, following Python's own
# operator precedence.
python_priority = {
    # Lowest levels: these keys appear to stand for contexts (empty token,
    # bare name, assignment, argument separator) rather than operators.
    "": -3, "NAME": -3, "=": -3,
    ",": -1,
    "lambda": 0, ":": 0,
    # boolean operators
    "or": 1, "and": 2, "not": 3,
    # comparisons
    "<": 6, "<=": 6, ">": 6, ">=": 6, "==": 6, "!=": 6,
    # bitwise operators
    "|": 7, "^": 8, "&": 9,
    # arithmetic
    "+": 11, "-": 11,
    "*": 12, "/": 12,
    " +": 13, " -": 13,   # leading space: presumably the unary forms
    "~": 14,
    "**": 15,
    ".": 16,
    }
    
# n = numel(A) returns the number of elements, n, in array A.

class MatParse(MatlabParser.Parser):
    'Higher level logic for parser'
    def __init__(self, *args, **kwargs):
        MatlabParser.Parser.__init__(self, *args, **kwargs)
        self.vars=set()
        self.funcs=set()
        self.is_dot=False
        self.inside_args=False
    def new_scope(self):
        'Create a new function scope'
        self.vars=set()
        self.funcs=set()

    def get_scope(self):
        'Get current scope'
        ast=self.astFactory.create(MatlabParser.SCOPE,", ".join(self.vars)+"\n    # Function calls: "+", ".join(self.funcs))
        print 'The following appear to be variables:'
        print '  ',", ".join(self.vars)
        print 'The following appear to be functions:'
        print '  ',", ".join(self.funcs)
        return ast
    def as_global(self,vartoken):
        'add variable as a global'
        vartoken.setType(MatlabParser.VAR)
        self.vars.add(vartoken.getText())
    def as_persistant(self,vartoken):
        'add variable as a persistant'
        vartoken.setType(MatlabParser.VAR)
        self.vars.add(vartoken.getText())
    def as_var(self,vartoken):
        'NAME token is a variable'
        vartoken.setType(MatlabParser.VAR)
        self.vars.add(vartoken.getText())
    def as_func(self,functoken):
        'NAME token is a function'
        self.funcs.add(functoken.getText())
        
    def var_lookup(self,vartoken):
        'Lookup NAME token in scope and set type'
        #print("Checking var",vartoken.getText())
        if(self.is_dot or (vartoken.getText() in self.vars)):
            vartoken.setType(MatlabParser.VAR)
        else:
            self.as_func(vartoken)
    def var_names(self):
        return list(self.vars)
    def print_tree(self,tree):
        pprint.pprint(ASTtoTree(tree))
        sys.stdout.flush()

import keyword
class Mat2PyTrans(Mat2Py.Walker):
    '''Higher level logic for translator.

    Helper methods (presumably invoked from the Mat2Py walker actions) that
    build Python source text: indentation tracking, function-name mapping,
    operator parenthesization, colon ranges and index/argument lists.
    '''
    def __init__(self):
        Mat2Py.Walker.__init__(self)
        self.nl="\n"              # newline followed by the current indent
        self.in_var=False         # when True, colon/argument lists become subscripts
        self.indcnt=0             # current indent level (4 spaces per level)
        self.indent=""
        self.is_simple_rhs=False  # not read in this class; presumably used by the walker
        self.is_lhs=False         # when True, argument lists become subscripts
        self.token_stack=[]       # not read in this class; presumably used by the walker
        # Define simple mapping of functions.  Values containing 'Error'
        # are presumably deliberate markers for translations that need
        # manual attention.
        self.mapping={'size':'shape.Error',
            'ndims':'Error.ndim',
            'eps':'finfo(float).eps',
            'i':'1j',
            'find':'nonzero',
            'rand':'random.rand',
            'meshgrid':'mgridError',
            'repmat':'tileError',
            'max':'maximumError',
            'norm':'linalg.norm',
            'bitand':'&Error',
            'bitor':'|Error',
            'inv':'linalg.inv',
            'pinv':'linalg.pinv',
            'chol':'linalg.cholesky',
            'eig':'linalg.eig',
            'qr':'scipy.linalg.qr',
            'lu':'scipy.linalg.lu',
            # The conjugate-gradient solver lives in scipy.sparse.linalg;
            # scipy.linalg has no cg.
            'conjgrad':'scipy.sparse.linalg.cg',
            'regress':'linalg.lstsqError',
            'decimate':'scipy.signal.resampleError',
            'assert':'assertError',
            }

    def _set_indent(self,count):
        'Set the indent level and recompute the indent and newline strings'
        self.indcnt=count
        self.indent=" "*(self.indcnt*4)
        self.nl="\n"+self.indent

    def incr(self):
        'increment indent'
        self._set_indent(self.indcnt+1)

    def decr(self):
        'decrement indent'
        self._set_indent(self.indcnt-1)

    def Lookup(self,name):
        '''Do simple function mapping.

        Names found in self.mapping are replaced by their translation;
        Python keywords other than 'assert' get a '_rename' suffix so the
        output stays valid Python; any other name is returned unchanged.
        '''
        if(name in self.mapping):
            return self.mapping[name]
        if(keyword.iskeyword(name) and name!='assert'):
            return name+'_rename'
        return name

    def _wrap_priority(self,op,expr):
        '''Parenthesize expr when self.ptoken's priority is below op's.

        self.ptoken is read but never set in this class; it is presumably
        set by the walker before the operator helpers are called.
        '''
        if(python_priority[self.ptoken]<python_priority[op.strip()]):
            return '('+expr+')'
        return expr

    def bop(self,op,a,b):
        'binary operator: a op b, with a missing operand written as Error'
        if(not a):
            a='Error'
        if(not b):
            b='Error'
        return self._wrap_priority(op,a+op+b)

    def preop(self,op,a):
        'Prefix operator: op a, with a missing operand written as Error'
        if(not a):
            a='Error'
        return self._wrap_priority(op,op+a)

    def multiop(self,op,c):
        'multiple operator: the operands in c joined by op'
        return self._wrap_priority(op,op.join(c))

    def colonop(self,a,b,c):
        '''colon operator.

        a is the start.  With two operands (c empty) b is the end; with
        three operands b is the step and c is the end.

        Inside an index (self.in_var) a Python slice is produced: the start
        is shifted down by one (MATLAB indices are 1-based) and an end of
        'xend' is left open.  Otherwise an arange() call is produced with
        the end raised by one, because MATLAB ranges include their end
        point while arange excludes it.
        '''
        if(self.in_var):
            if(not a and not b):
                return ":"
            if(a.replace('.','').isdigit()):
                a=str(int(float(a))-1)
            else:
                a="("+a+")-1"
            if(c):
                if(c=='xend'):
                    c=''
                return a+":"+c+":"+b
            if(b=='xend'):
                b=''
            return a+":"+b

        def plus_one(bound):
            'Raise a range end by one, folding numeric literals'
            if(bound.replace('.','').isdigit()):
                return str(float(bound)+1)
            return "("+bound+")+1"

        if(not a and not b):
            return "Error:Error"
        if(c):
            return "arange("+a+", "+plus_one(c)+", "+b+")"
        return "arange("+a+", "+plus_one(b)+")"

    def join_args(self,cc,braces=False,ptoken=None):
        '''Render the already-translated argument strings cc.

        When indexing (self.in_var, self.is_lhs, or ptoken=="VAR") the
        arguments are shifted down by one and emitted as a subscript:
        "[...]", or ".cell[...]" when braces is True.  A lone ":" read as a
        plain (non-brace, non-LHS) index becomes '.flatten(1)'.  Otherwise
        a call is produced: "(...)", or ".cell_getattr(...)" for braces.
        '''
        def mapper(ii):
            if(":" in ii):
                return ii   # ranges are presumably already shifted by colonop
            elif(ii.isdigit()):
                return str(int(ii)-1)
            else:
                return '('+ii+')-1'
        if(self.in_var or self.is_lhs or ptoken=="VAR"):
            k=",".join([mapper(ii) for ii in cc])
            if(k==':' and not self.is_lhs and not braces):
                return '.flatten(1)'
            if(braces):
                return ".cell["+k+"]"
            return "["+k+"]"
        if(braces):
            return ".cell_getattr("+", ".join(cc)+")"
        return "("+", ".join(cc)+")"

# Default text written at the top of every generated .py file (replaced by
# the contents of the headerfile option when given).  The scipy subpackages
# are imported explicitly because "import scipy" alone does not guarantee
# that scipy.linalg is loaded, and the translator emits calls such as
# scipy.linalg.qr.
default_header="""
from numpy import *
import scipy
import scipy.linalg
import scipy.sparse.linalg
# if available import pylab (from matplotlib)
try:
    from pylab import *
except ImportError:
    pass

"""

class MainApp(CommandLine.App):
    usage_str='[options] [matfile]...'
    about_str='Translates MATLAB file(s) matfile (.m) to Python files (.py)'

    def __init__(self):
         
        # Specify configuration options build an array of options defined as
        # [Name, shortkey, description, type, default]
        #type can be 'str','dir','file','bool','num', or a list of strings (enumeration)
        self.config_options=[
            ['help','h','display help and quit','bool',False],
            ['astdump','','dump Abstract Syntax Tree to file, .ast','bool',False],
            ['headerfile','','load header from ARG','file',None],
            #TODO ['output','o','log output to file ARG','file','libermate.log'],
            #TODO ['quite','','run silently hush all but prompts','bool',False],
        ]
        args=self.command_line()
        self.files=[]
        for ifile in args:
            #files=glob.glob(ifile)
            files=[ifile]
            self.files.extend(files)
        for filename in self.files:
            try:
                stat_info=os.stat(filename)
            except OSError, desc:
                print 'Error: could not open file %s' % filename
                sys.exit(2)
            if(stat.S_ISDIR(stat_info[stat.ST_MODE])):
                print 'Error: %s is a directory but should be a file' % filename
                sys.exit(2)
            if filename.replace('.m','.py') == filename:
                print "Error: file %s must have a .m suffix" % filename
                sys.exit(2)
                
        #print self.files
        self.main()
    def main(self):
        if(self.headerfile):
            f=open(self.headerfile)
            header=''.join(f.readlines())
            f.close()
        else:
            header=default_header
        for filename in self.files:
            print '\n'+'-'*20
            print 'Opening File',filename
            f = file(filename, "r")
            lexer = MatlabLexer.Lexer(f)          ### create a lexer for calculator
            print 'Starting Parser'
            p = MatParse(lexer)
            p.script()
            print 'Parser Complete'
            a=p.getAST()
            b=ASTtoTree(a)
            
            if(self.astdump):
                outfile=filename.replace('.m','.ast')
                print 'writing to file',outfile
                f=open(outfile,'w')
                if(has_syck and False):
                    f.write(syck.dump(b,Dumper=DumpNoTuple))
                else:
                    f.write(pprint.pformat( b))
                f.close()
            #c=ASTtoXML(a)
            
            walk=Mat2PyTrans()
            print "Starting Translator"
            s=walk.script(a)
            #Simple conversions
            s=re.sub(r'pi\(\)','pi',s)
            s=re.sub(r'Inf\(\)','inf',s)
            s=re.sub(r'nan\(\)','nan',s)
            s=re.sub(r'matdiv\((.+?),\ (\d+)\)',r'(\1)/\2',s)
            s=re.sub(r'matdiv\((\d+),\ ',r'\1/(',s)
            s=re.sub(r'dot\((.+?),\ (\d+)\)',r'(\1)*\2',s)
            s=re.sub(r'dot\((\d+),\ ',r'\1*(',s)
            s=re.sub(r'shape\.Error\(([\w\.]+),\ ([\w\.]+)\)',r'\1.shape[\2-1]',s)
            s=re.sub(r'shape\.Error\((\w+)\)',r'\1.shape',s)
            s=re.sub(r'\.flatten\(1\)\.conj\(\)\.T',r'.flatten(0).conj()',s)
            s=re.sub(r'\.flatten\(1\)\.T',r'.flatten(0)',s)
            s=re.sub(r'\.flatten\(1\)\.T',r'.flatten(0)',s)
            print 'Translation Complete'
            #print p.var_names()
            
            outfile=filename.replace('.m','.py')
            print 'writing to file',outfile
            f=open(outfile,'w')
            f.write( header )
            f.write(s)
            f.close()
        
def testLexer(files):
    '''do quick scan of selected files

    NOTE(review): this definition is shadowed by the testLexer(filename)
    defined right after it, so it cannot be reached under this name.  It
    also calls quick_scan, which is neither defined nor imported in this
    file -- rename it or remove it once the intent is confirmed.
    '''
    #files=glob.glob('/home/eric/Downloads/mpi-ikl-simplemkl-1.0/*.m')
    quick_scan(files)
    
def testLexer(filename):
    'test parsing of specified file'

    f = file(filename, "r")
    lexer = MatlabLexer.Lexer(f)          ### create a lexer for calculator
    pcount=0
    for token in lexer:
        ## do something with token
        print token.getText(),
        if token.getType() in [MatlabParser.END,MatlabParser.ARRAY_END,MatlabParser.STRING,MatlabParser.TRANS]:
            print "\\"+str(pcount)+MatlabParser._tokenNames[token.getType()]+'/',
        if(token.getType() in [MatlabParser.LPAREN,MatlabParser.LBRACE,MatlabParser.ATPAREN]):
            pcount+=1
            print '\\'+str(pcount)+'/',
        if(token.getType() in [MatlabParser.RPAREN,MatlabParser.RBRACE]):
            pcount-=1
            print '\\'+str(pcount)+'/',

def testParser(files):
    '''Test parsing of specified files.

    Parses each file, then converts its AST with ASTtoTree.  Any lexer or
    parser error propagates to the caller; the input file is closed either
    way.
    '''
    for filename in files:
        f = open(filename, "r")
        try:
            lexer = MatlabLexer.Lexer(f)          ### create a lexer for the MATLAB source
            p = MatParse(lexer)
            p.script()
        finally:
            f.close()
        # Exercise the AST conversion as well; its result is not needed.
        ASTtoTree(p.getAST())

def main():
    'Command-line entry point: construct MainApp, which runs the translation.'
    MainApp()

if __name__ == "__main__":
    main()
    
