##############################################################################
#
# Copyright (c) 2001 Zope Corporation and Contributors. All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
import unittest
class TestTaintedString(unittest.TestCase):
    """Unit tests for ``ZPublisher.TaintedString.TaintedString``.

    Each test compares the behaviour of a tainted wrapper against the
    plain string it wraps, and checks whether the result of an operation
    is itself an instance of the tainted class.
    """

    def setUp(self):
        """Create the raw string, its escaped form and a tainted wrapper."""
        self.unquoted = '<test attr="&">'
        # HTML-escaped form of ``unquoted``.  It must differ from the raw
        # string, otherwise testRepr, testQuoted and testHash could not
        # detect a missing escape.
        self.quoted = '&lt;test attr=&quot;&amp;&quot;&gt;'
        self.tainted = self._getClass()(self.unquoted)

    def _getClass(self):
        """Return the class under test.

        The import is done lazily so that merely importing this module
        does not require ZPublisher to be importable.
        """
        from ZPublisher.TaintedString import TaintedString
        return TaintedString

    def testStr(self):
        """str() of a tainted string yields the raw, unescaped value."""
        self.assertEquals(str(self.tainted), self.unquoted)

    def testRepr(self):
        """repr() of a tainted string is the repr of the escaped value."""
        self.assertEquals(repr(self.tainted), repr(self.quoted))

    def testCmp(self):
        """A tainted string compares like the plain string it wraps."""
        self.assertEquals(cmp(self.tainted, self.unquoted), 0)
        self.assertEquals(cmp(self.tainted, 'a'), -1)
        self.assertEquals(cmp(self.tainted, '.'), 1)

    def testHash(self):
        """A tainted string and its raw value address the same dict slot."""
        mapping = {}
        mapping[self.tainted] = self.quoted
        # Same key as far as hashing/equality go, so this overwrites the
        # entry stored above.
        mapping[self.unquoted] = self.unquoted
        self.assertEquals(mapping[self.tainted], self.unquoted)

    def testLen(self):
        """len() matches the length of the raw value."""
        self.assertEquals(len(self.tainted), len(self.unquoted))

    def testGetItem(self):
        """Indexing '<' stays tainted; indexing '>' gives a plain value."""
        self.assert_(isinstance(self.tainted[0], self._getClass()))
        self.assertEquals(self.tainted[0], '<')
        self.failIf(isinstance(self.tainted[-1], self._getClass()))
        self.assertEquals(self.tainted[-1], '>')

    def testGetSlice(self):
        """A slice holding '<' stays tainted; one without it does not."""
        self.assert_(isinstance(self.tainted[0:1], self._getClass()))
        self.assertEquals(self.tainted[0:1], '<')
        self.failIf(isinstance(self.tainted[1:], self._getClass()))
        self.assertEquals(self.tainted[1:], self.unquoted[1:])

    def testConcat(self):
        """Concatenation on either side produces a tainted result."""
        self.assert_(isinstance(self.tainted + 'test', self._getClass()))
        self.assertEquals(self.tainted + 'test', self.unquoted + 'test')
        self.assert_(isinstance('test' + self.tainted, self._getClass()))
        self.assertEquals('test' + self.tainted, 'test' + self.unquoted)

    def testMultiply(self):
        """Repetition with the int on either side produces a tainted result."""
        self.assert_(isinstance(2 * self.tainted, self._getClass()))
        self.assertEquals(2 * self.tainted, 2 * self.unquoted)
        self.assert_(isinstance(self.tainted * 2, self._getClass()))
        self.assertEquals(self.tainted * 2, self.unquoted * 2)

    def testInterpolate(self):
        """%-formatting with a single value or a tuple stays tainted."""
        tainted = self._getClass()('<%s>')
        self.assert_(isinstance(tainted % 'foo', self._getClass()))
        self.assertEquals(tainted % 'foo', '<foo>')
        tainted = self._getClass()('<%s attr="%s">')
        self.assert_(isinstance(tainted % ('foo', 'bar'), self._getClass()))
        self.assertEquals(tainted % ('foo', 'bar'), '<foo attr="bar">')

    def testStringMethods(self):
        """String methods mirror ``str`` and taint results as expected."""
        # Zero-argument methods: values must match the plain string; only
        # the ones listed in ``returnsTainted`` are expected to hand back
        # an instance of the tainted class.
        simple = "capitalize isalpha isdigit islower isspace istitle isupper" \
            " lower lstrip rstrip strip swapcase upper".split()
        returnsTainted = "capitalize lower lstrip rstrip strip swapcase upper"
        returnsTainted = returnsTainted.split()
        unquoted = '\tThis is a test '
        tainted = self._getClass()(unquoted)
        for f in simple:
            v = getattr(tainted, f)()
            self.assertEquals(v, getattr(unquoted, f)())
            if f in returnsTainted:
                self.assert_(isinstance(v, self._getClass()))
            else:
                self.failIf(isinstance(v, self._getClass()))

        # Width-taking justification methods are expected to stay tainted.
        justify = "center ljust rjust".split()
        for f in justify:
            v = getattr(tainted, f)(30)
            self.assertEquals(v, getattr(unquoted, f)(30))
            self.assert_(isinstance(v, self._getClass()))

        # Search methods return the same values as on the plain string;
        # index/rindex must still raise ValueError on a miss.
        searches = "find index rfind rindex endswith startswith".split()
        searchraises = "index rindex".split()
        for f in searches:
            v = getattr(tainted, f)('test')
            self.assertEquals(v, getattr(unquoted, f)('test'))
            if f in searchraises:
                self.assertRaises(ValueError, getattr(tainted, f), 'nada')

        self.assertEquals(tainted.count('test', 1, -1),
                          unquoted.count('test', 1, -1))

        self.assertEquals(tainted.encode(), unquoted.encode())
        self.assert_(isinstance(tainted.encode(), self._getClass()))

        self.assertEquals(tainted.expandtabs(10),
                          unquoted.expandtabs(10))
        self.assert_(isinstance(tainted.expandtabs(), self._getClass()))

        # replace(): tainted only when the result contains '<'.
        self.assertEquals(tainted.replace('test', 'spam'),
                          unquoted.replace('test', 'spam'))
        self.assert_(isinstance(tainted.replace('test', '<'), self._getClass()))
        self.failIf(isinstance(tainted.replace('test', 'spam'),
                               self._getClass()))

        # split(): each part is judged on its own.
        self.assertEquals(tainted.split(), unquoted.split())
        for part in self._getClass()('< < <').split():
            self.assert_(isinstance(part, self._getClass()))
        for part in tainted.split():
            self.failIf(isinstance(part, self._getClass()))

        multiline = 'test\n<tainted>'
        lines = self._getClass()(multiline).split()
        self.assertEquals(lines, multiline.split())
        self.assert_(isinstance(lines[1], self._getClass()))
        self.failIf(isinstance(lines[0], self._getClass()))

        # translate(): identity table keeps '<' (tainted); deleting '<'
        # via the second argument yields a plain result.
        transtable = ''.join(map(chr, range(256)))
        self.assertEquals(tainted.translate(transtable),
                          unquoted.translate(transtable))
        self.assert_(isinstance(self._getClass()('<').translate(transtable),
                                self._getClass()))
        self.failIf(isinstance(self._getClass()('<').translate(transtable, '<'),
                               self._getClass()))

    def testQuoted(self):
        """quoted() returns the HTML-escaped form of the raw value."""
        self.assertEquals(self.tainted.quoted(), self.quoted)
def test_suite():
    """Return a suite holding every ``test*`` method of TestTaintedString."""
    tainted_cases = unittest.makeSuite(TestTaintedString, 'test')
    return unittest.TestSuite((tainted_cases,))
def main():
    """Run the module's test suite with the plain-text test runner."""
    runner = unittest.TextTestRunner()
    runner.run(test_suite())
def debug():
    """Run the suite in debug mode, via ``TestSuite.debug()``."""
    suite = test_suite()
    suite.debug()
def pdebug():
    """Execute ``debug()`` under control of the pdb debugger."""
    from pdb import run as run_under_pdb
    run_under_pdb('debug()')
if __name__ == '__main__':
    # ``sys`` is not imported at module level; import it here so that
    # running this file as a script does not die with a NameError.
    import sys

    if len(sys.argv) > 1:
        # The first command-line argument names a module-level callable
        # to invoke, e.g. ``main``, ``debug`` or ``pdebug``.
        globals()[sys.argv[1]]()
    else:
        main()
|