quad_forms.py :  » Math » Modular-toolkit-for-Data-Processing » MDP-2.6 » mdp » utils » Python Open Source

Home
Python Open Source
1.3.1.2 Python
2.Ajax
3.Aspect Oriented
4.Blog
5.Build
6.Business Application
7.Chart Report
8.Content Management Systems
9.Cryptographic
10.Database
11.Development
12.Editor
13.Email
14.ERP
15.Game 2D 3D
16.GIS
17.GUI
18.IDE
19.Installer
20.IRC
21.Issue Tracker
22.Language Interface
23.Log
24.Math
25.Media Sound Audio
26.Mobile
27.Network
28.Parser
29.PDF
30.Project Management
31.RSS
32.Search
33.Security
34.Template Engines
35.Test
36.UML
37.USB Serial
38.Web Frameworks
39.Web Server
40.Web Services
41.Web Unit
42.Wiki
43.Windows
44.XML
Python Open Source » Math » Modular toolkit for Data Processing 
Modular toolkit for Data Processing » MDP 2.6 » mdp » utils » quad_forms.py
import mdp
from routines import refcast
# short-hand aliases for the numerical backend modules exposed by mdp
numx, numx_linalg = mdp.numx, mdp.numx_linalg

# relative tolerance used by the bisection in QuadraticForm._maximize:
# ten times the double-precision machine epsilon
epsilon = numx.finfo(numx.double).eps * 10

class QuadraticForm(object):
    """
    Define an inhomogeneous quadratic form as 1/2 x'Hx + f'x + c .
    This class implements the quadratic form analysis methods
    presented in:

    Berkes, P. and Wiskott, L. (2006). On the analysis and interpretation
    of inhomogeneous quadratic forms as receptive fields. Neural
    Computation, 18(8): 1868-1895.
    """

    def __init__(self, H, f=None, c=None, dtype='d'):
        """
        The quadratic form is defined as 1/2 x'Hx + f'x + c .

        H     -- square matrix of the quadratic term; it must be symmetric
                 up to a tolerance derived from the precision of 'dtype'
        f     -- vector of the linear term (default: zero vector)
        c     -- constant term (default: 0)
        dtype -- numerical type of the internal structures

        Raises mdp.MDPException if H is not (almost) symmetric.
        """
        local_eps = 10*numx.finfo(numx.dtype(dtype)).eps
        # check that H is almost symmetric
        if not numx.allclose(H, H.T, rtol=100*local_eps, atol=local_eps):
            raise mdp.MDPException('H does not seem to be symmetric')

        self.H = refcast(H, dtype)
        if f is None:
            f = numx.zeros((H.shape[0],), dtype=dtype)
        if c is None:
            c = 0
        self.f = refcast(f, dtype)
        self.c = c
        self.dtype = dtype

    def apply(self, x):
        """Apply the quadratic form to the input vectors.

        x -- a single vector or a 2-D array whose rows are input vectors

        Return 1/2 x'Hx + f'x + c, one value per row of x."""
        x = numx.atleast_2d(x)
        # (x H') * x summed over columns gives x'Hx for every row at once
        return (0.5*(mdp.utils.mult(x, self.H.T)*x).sum(axis=1) +
               mdp.utils.mult(x, self.f) + self.c)

    def _eig_sort(self, x):
        """Return the eigenvalues of 'x' in ascending order and the matching
        eigenvectors as columns. Imaginary parts are discarded, since the
        matrices passed here are (almost) symmetric."""
        E, W = numx_linalg.eig(x)
        E, W = E.real, W.real
        idx = E.argsort()
        E = E.take(idx)
        W = W.take(idx, axis=1)
        return E, W

    def get_extrema(self, norm, tol = 1.E-4):
        """
        Find the input vectors xmax and xmin with norm 'norm' that maximize
        or minimize the quadratic form.

        If f is non-zero and H is strictly definite, the unconstrained
        stationary point x0 (H x0 + f = 0) is returned for the corresponding
        extremum whenever norm(x0) <= norm. In that case the returned vector
        can have a smaller norm.

        norm -- norm of the sphere on which the extrema are searched
        tol  -- tolerance on the squared norm of the solution, used by the
                bisection search for extrema lying on the sphere

        Returns (xmax, xmin). Both vectors are also stored in
        self.xmax and self.xmin.
        """
        H, f = self.H, self.f
        E, W = self._eig_sort(H)
        if abs(f).max() < numx.finfo(self.dtype).eps:
            # homogeneous form: the extrema on the sphere lie along the
            # eigenvectors of the largest and smallest eigenvalues
            xmax = W[:, -1]*norm
            xmin = W[:, 0]*norm
        else:
            # strict inequalities: a singular H (zero eigenvalue) has no
            # unique stationary point and must not be passed to 'solve'
            H_definite_positive = E[0] > 0
            H_definite_negative = E[-1] < 0

            x0 = None
            if H_definite_positive or H_definite_negative:
                # the stationary point is only needed (and only unique)
                # when H is definite
                x0 = mdp.numx_linalg.solve(-H, f)
            if H_definite_positive and mdp.utils.norm2(x0) <= norm:
                # x0 is a minimum inside the sphere
                xmin = x0
            else:
                xmin = self._maximize(norm, tol, factor=-1)
            if H_definite_negative and mdp.utils.norm2(x0) <= norm:
                # x0 is a maximum inside the sphere
                xmax = x0
            else:
                xmax = self._maximize(norm, tol, factor=None)

        self.xmax, self.xmin = xmax, xmin
        return xmax, xmin

    def _maximize(self, norm, tol = 1.E-4, x0 = None, factor = None):
        """
        Maximize 1/2 x'Hx + f'x on the sphere of radius 'norm'.

        The Lagrange condition (lambda*Id - H) x = f is solved in the
        eigenbasis of H. The multiplier lambda (larger than the largest
        eigenvalue of H) is found by bisection until
        abs(norm(x)**2 - norm**2) <= tol.

        norm   -- radius of the sphere
        tol    -- tolerance on the squared norm of the solution
        x0     -- if given, the linear term is shifted to f + H x0 and x0
                  is added to the result
        factor -- if given, H and f are multiplied by it first
                  (factor=-1 turns the maximization into a minimization)
        """
        H, f = self.H, self.f
        if factor is not None:
            H = factor*H
            f = factor*f
        if x0 is not None:
            x0 = mdp.utils.refcast(x0, self.dtype)
            f = mdp.utils.mult(H, x0)+ f
            # c = 0.5*x0'*H*x0 + f'*x0 + c -> not needed: a constant
            # offset does not move the maximum
        mu, V = self._eig_sort(H)
        # coordinates of f in the eigenbasis of H
        alpha = mdp.utils.mult(V.T, f).reshape((H.shape[0],))
        # left bound for lambda: the largest eigenvalue
        ll = mu[-1]
        # right bound for lambda: there norm(x) <= norm(f)/(lambda-ll) = norm
        lr = mdp.utils.norm2(f)/norm + ll
        if not lr > ll:
            # f vanishes relative to the spectrum, so there is no interval
            # to bisect. The maximum lies along the top eigenvector; pick
            # the sign that does not decrease f'x.
            sign = -1. if alpha[-1] < 0 else 1.
            x = sign*norm*V[:, -1]
        else:
            # column i of Va is alpha_i * v_i (alpha acts as a row vector)
            Va = V*alpha
            norm_2 = norm**2
            norm_x_2 = 0
            # solution at the right bound, so that beta is always defined
            # even if the loop body never runs (lr > mu everywhere here)
            beta = (lr-mu)**(-1)
            # search by bisection until norm(x)**2 = norm**2. The interval
            # width is compared against abs(lr) because lr can be (or
            # become) negative when all eigenvalues of H are negative.
            while abs(norm_x_2-norm_2) > tol and (lr-ll) > epsilon*abs(lr):
                # bisection of the lambda-interval
                lambd = 0.5*(lr-ll)+ll
                # eigenvalues of (lambda*Id - H)^-1
                beta = (lambd-mu)**(-1)
                # squared norm of the solution for this lambda
                norm_x_2 = (alpha**2*beta**2).sum()
                # norm(x) decreases as lambda grows
                if norm_x_2 > norm_2:
                    ll = lambd
                else:
                    lr = lambd
            x = (Va*beta).sum(axis=1)
        # x0 is an array: compare against None instead of testing its
        # (ambiguous) truth value
        if x0 is not None:
            x = x + x0
        return x

    def get_invariances(self, xstar):
        """Compute invariances of the quadratic form at extremum 'xstar'.

        The quadratic term is restricted to the tangential plane of the
        sphere of radius norm(xstar) at xstar. The quantity
        (xstar'H xstar + f'xstar)/norm(xstar)**2 is then subtracted from
        its eigenvalues.

        Outputs:

         w  -- w[:,i] is the direction of the i-th invariance
         nu -- nu[i] second derivative on the sphere in the direction w[:,i]

        Invariances are sorted by increasing absolute value of nu.
        """
        # find a basis for the tangential plane of the sphere in xstar:
        # QR decomposition of [xstar, e(2), ..., e(N)], where
        # e(1) ... e(N) is the canonical basis for R^N
        r = mdp.utils.norm2(xstar)
        P = numx.eye(xstar.shape[0], dtype=xstar.dtype)
        P[:, 0] = xstar
        Q, R = numx_linalg.qr(P)
        # the orthogonal subspace
        B = Q[:, 1:]
        # restrict the matrix H to the tangential plane
        Ht = mdp.utils.mult(B.T, mdp.utils.mult(self.H, B))
        # compute the invariances
        nu, w = self._eig_sort(Ht)
        nu -= ((mdp.utils.mult(self.H, xstar)*xstar).sum()
               +(self.f*xstar).sum())/(r*r)
        idx = abs(nu).argsort()
        nu = nu[idx]
        w = w[:, idx]
        # map the directions back from the tangential basis to R^N
        w = mdp.utils.mult(B, w)
        return w, nu
        
www.java2java.com | Contact Us
Copyright 2009 - 12 Demo Source and Support. All rights reserved.
All other trademarks are property of their respective owners.