#!/usr/bin/env python
# (C) 2000 Huaiyu Zhu <hzhu@users.sourceforge.net>. Licence: GPL
# $I$
"""
Auxiliary utilities for MatPy:
formatting of arrays for display
removal of length-one axes
"""
#------------------------------------------------------------------
# Formatting functions
# ??? Is there a better way to define the format for complex numbers?
import string
def shortI(x):
    """Format integer x compactly: sign-or-space prefix, left-justified to at least 3 columns."""
    spec = "% -3d"
    return spec % x
def shortF(x):
    """Format float x with 3 significant digits, sign-or-space prefix, left-justified to at least 6 columns."""
    spec = "% -6.3g"
    text = spec % x
    return text
def shortC(x):
    """Format complex x as "re+imj" with 3 significant digits per part, right-justified to at least 8 columns."""
    re_part, im_part = x.real, x.imag
    pair = "% .3g% +.3gj" % (re_part, im_part)
    return "%8s" % pair
def longI(x):
    """Format integer x left-justified to at least 11 columns (no sign padding)."""
    spec = "%-11d"
    return spec % x
def longF(x):
    """Format float x with 8 significant digits, left-justified to at least 11 columns."""
    spec = "%-11.8g"
    text = spec % x
    return text
def longC(x):
    """Format complex x as "re+imj" with 8 significant digits per part, right-justified to at least 19 columns."""
    re_part, im_part = x.real, x.imag
    pair = "% .8g% +.8gj" % (re_part, im_part)
    return "%19s" % pair
#------------------------------------------------------------------
# Active element formatters.  The repr*/str* helpers look these module-level
# names up at call time, so rebinding them (for example to longI/longF/longC)
# changes how those helpers render elements.
formatI, formatF, formatC = shortI, shortF, shortC
def reprI(seq, n=0):
    """Return formatted representation of an integer multiarray.

    A rank <= 1 input becomes "[ a, b, ... ]" with each element passed
    through formatI.  Higher-rank input recurses over the first axis; the
    sub-representations are joined with ",\n" plus n+1 spaces of indent and
    wrapped in "[" ... "]".

    seq -- array-like exposing a ``shape`` attribute
    n   -- current nesting depth, used only for indentation
    """
    if len(seq.shape) > 1:
        depth = n + 1
        rows = []
        for sub in seq:
            rows.append(reprI(sub, depth))
        return "[" + string.join(rows, ',\n' + ' ' * depth) + "]"
    items = map(formatI, list(seq))
    return "[ " + string.join(items, ', ') + " ]"
def strI(seq, n=0):
    """Return formatted string (display form) of an integer multiarray.

    A rank <= 1 input is rendered as its formatI-formatted elements separated
    by single spaces, with no brackets.  Higher-rank input recurses over the
    first axis; the sub-strings are joined with a newline plus n+1 spaces of
    indent and wrapped in "[" ... " ]".

    seq -- array-like exposing a ``shape`` attribute
    n   -- current nesting depth, used only for indentation
    """
    if len(seq.shape) <= 1:
        return string.join(map(formatI, list(seq)), ' ')
    else:
        # n=n binds the current depth into the lambda at definition time.
        return "[" + string.join(
            map(lambda x, n=n: strI(x, n + 1), seq), '\n' + ' ' * (n + 1)) + " ]"
    # (Removed an unreachable trailing "return s"; ``s`` was never defined.)
def reprF(seq, n=0):
    """Return formatted representation of a float multiarray.

    A rank <= 1 input becomes "[ a, b, ... ]" with each element passed
    through formatF.  Higher-rank input recurses over the first axis; the
    sub-representations are joined with ",\n" plus n+1 spaces of indent and
    wrapped in "[" ... "]".

    seq -- array-like exposing a ``shape`` attribute
    n   -- current nesting depth, used only for indentation
    """
    if len(seq.shape) > 1:
        depth = n + 1
        rows = []
        for sub in seq:
            rows.append(reprF(sub, depth))
        return "[" + string.join(rows, ',\n' + ' ' * depth) + "]"
    items = map(formatF, list(seq))
    return "[ " + string.join(items, ', ') + " ]"
def strF(seq, n=0):
    """Return formatted string (display form) of a float multiarray.

    A rank <= 1 input is rendered as its formatF-formatted elements separated
    by single spaces, with no brackets.  Higher-rank input recurses over the
    first axis; the sub-strings are joined with a newline plus n+1 spaces of
    indent and wrapped in "[" ... " ]".

    seq -- array-like exposing a ``shape`` attribute
    n   -- current nesting depth, used only for indentation
    """
    if len(seq.shape) <= 1:
        return string.join(map(formatF, list(seq)), ' ')
    else:
        # n=n binds the current depth into the lambda at definition time.
        return "[" + string.join(
            map(lambda x, n=n: strF(x, n + 1), seq), '\n' + ' ' * (n + 1)) + " ]"
    # (Removed an unreachable trailing "return s"; ``s`` was never defined.)
def reprC(seq, n=0):
    """Return formatted representation of a complex multiarray.

    A rank <= 1 input becomes "[ a, b, ... ]" with each element passed
    through formatC.  Higher-rank input recurses over the first axis; the
    sub-representations are joined with ",\n" plus n+1 spaces of indent and
    wrapped in "[" ... "]".

    seq -- array-like exposing a ``shape`` attribute
    n   -- current nesting depth, used only for indentation
    """
    if len(seq.shape) <= 1:
        # Close with " ]" (not "]") so the brackets are padded symmetrically,
        # matching the integer and float variants reprI/reprF.
        return "[ " + string.join(map(formatC, list(seq)), ', ') + " ]"
    else:
        # n=n binds the current depth into the lambda at definition time.
        return "[" + string.join(
            map(lambda x, n=n: reprC(x, n + 1), seq), ',\n' + ' ' * (n + 1)) + "]"
def strC(seq, n=0):
    """Return formatted string (display form) of a complex multiarray.

    A rank <= 1 input is rendered as its formatC-formatted elements separated
    by single spaces, with no brackets.  Higher-rank input recurses over the
    first axis; the sub-strings are joined with a newline plus n+1 spaces of
    indent and wrapped in "[" ... " ]".

    seq -- array-like exposing a ``shape`` attribute
    n   -- current nesting depth, used only for indentation
    """
    if len(seq.shape) > 1:
        depth = n + 1
        pieces = []
        for row in seq:
            pieces.append(strC(row, depth))
        return "[" + string.join(pieces, '\n' + ' ' * depth) + " ]"
    return string.join(map(formatC, list(seq)), ' ')
#------------------------------------------------------------------
import Numeric
def compactAxes(A):
    """Return A with its axes of length 1 removed, by indexing.

    Each length-1 axis is indexed with 0, which drops it.  Every other axis
    is indexed with a full slice, which keeps it.  If every axis has length 1
    the index is all integers, so the result is presumably the single element
    itself rather than an array.  TODO confirm against Numeric semantics.
    """
    index = []
    for length in Numeric.shape(A):
        if length == 1:
            index.append(0)
        else:
            # Use the builtin slice.  The former ``Numeric.slice`` only works
            # if the Numeric module happens to re-export this builtin name.
            index.append(slice(None))
    return A[tuple(index)]
def DelAxes(m):
    """Return m reshaped with all axes of length one removed.

    Only axes whose length is exactly 1 are dropped.  Axes of length 0 are
    kept, so the new shape still has the same total size as m.
    """
    new_shape = []
    for length in m.shape:
        # Test "!= 1" rather than "> 1".  Dropping a length-0 axis would ask
        # reshape for a shape whose total size no longer matches m's size of
        # zero, and the reshape would fail.
        if length != 1:
            new_shape.append(length)
    return Numeric.reshape(m, new_shape)
|