001    // --- BEGIN LICENSE BLOCK ---
002    /* 
003     * Copyright (c) 2009, Mikio L. Braun
004     * All rights reserved.
005     * 
006     * Redistribution and use in source and binary forms, with or without
007     * modification, are permitted provided that the following conditions are
008     * met:
009     * 
010     *     * Redistributions of source code must retain the above copyright
011     *       notice, this list of conditions and the following disclaimer.
012     * 
013     *     * Redistributions in binary form must reproduce the above
014     *       copyright notice, this list of conditions and the following
015     *       disclaimer in the documentation and/or other materials provided
016     *       with the distribution.
017     * 
018     *     * Neither the name of the Technische Universität Berlin nor the
019     *       names of its contributors may be used to endorse or promote
020     *       products derived from this software without specific prior
021     *       written permission.
022     * 
023     * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
024     * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
025     * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
026     * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
027     * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
028     * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
029     * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
030     * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
031     * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
032     * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
033     * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
034     */
035    // --- END LICENSE BLOCK ---
036    
037    package org.jblas;
038    
039    /**
040     * <p>General functions which are geometric in nature.</p>
041     * 
042     * <p>For example, computing all pairwise squared distances between all columns of a matrix.</p>
043     */
044    public class Geometry {
045            
046            /**
047             * <p>Compute the pairwise squared distances between all columns of the two
048             * matrices.</p>
049             * 
050             * <p>An efficient way to do this is to observe that <i>(x-y)^2 = x^2 - 2xy - y^2</i>
051             * and to then properly carry out the computation with matrices.</p>
052             */
053            public static DoubleMatrix pairwiseSquaredDistances(DoubleMatrix X, DoubleMatrix Y) {
054                    if (X.rows != Y.rows)
055                            throw new IllegalArgumentException(
056                                            "Matrices must have same number of rows");
057            
058                    DoubleMatrix XX = X.mul(X).columnSums();
059                    DoubleMatrix YY = Y.mul(Y).columnSums();
060            
061                    DoubleMatrix Z = X.transpose().mmul(Y);
062                    Z.muli(-2.0); //Z.print();
063                    Z.addiColumnVector(XX);
064                    Z.addiRowVector(YY);
065            
066                    return Z;
067            }
068    
069            /** Center a vector (subtract mean from all elements (in-place). */
070            public static DoubleMatrix center(DoubleMatrix x) {
071                    return x.subi(x.mean());
072            }
073            
074            /** Center the rows of a matrix (in-place). */
075            public static DoubleMatrix centerRows(DoubleMatrix x) {
076                    DoubleMatrix temp = new DoubleMatrix(x.columns);
077                    for (int r = 0; r < x.rows; r++)
078                            x.putRow(r, center(x.getRow(r, temp)));
079                    return x;
080            }
081            
082            /** Center the columns of a matrix (in-place). */
083            public static DoubleMatrix centerColumns(DoubleMatrix x) {
084                    DoubleMatrix temp = new DoubleMatrix(x.rows);
085                    for (int c = 0; c < x.columns; c++)
086                            x.putColumn(c, center(x.getColumn(c, temp)));
087                    return x;
088            }
089            
090            /** Normalize a vector (scale such that its Euclidean norm is 1) (in-place). */
091            public static DoubleMatrix normalize(DoubleMatrix x) {
092                    return x.divi(x.norm2());
093            }
094    
095            /** Normalize the rows of a matrix (in-place). */
096            public static DoubleMatrix normalizeRows(DoubleMatrix x) {
097                    DoubleMatrix temp = new DoubleMatrix(x.columns);
098                    for (int r = 0; r < x.rows; r++)
099                            x.putRow(r, normalize(x.getRow(r, temp)));
100                    return x;
101            }
102            
103            /** Normalize the columns of a matrix (in-place). */
104            public static DoubleMatrix normalizeColumns(DoubleMatrix x) {
105                    DoubleMatrix temp = new DoubleMatrix(x.rows);
106                    for (int c = 0; c < x.columns; c++)
107                            x.putColumn(c, normalize(x.getColumn(c, temp)));
108                    return x;
109            }
110    
111    //BEGIN
112      // The code below has been automatically generated.
113      // DO NOT EDIT!
114            
115            /**
116             * <p>Compute the pairwise squared distances between all columns of the two
117             * matrices.</p>
118             * 
119             * <p>An efficient way to do this is to observe that <i>(x-y)^2 = x^2 - 2xy - y^2</i>
120             * and to then properly carry out the computation with matrices.</p>
121             */
122            public static FloatMatrix pairwiseSquaredDistances(FloatMatrix X, FloatMatrix Y) {
123                    if (X.rows != Y.rows)
124                            throw new IllegalArgumentException(
125                                            "Matrices must have same number of rows");
126            
127                    FloatMatrix XX = X.mul(X).columnSums();
128                    FloatMatrix YY = Y.mul(Y).columnSums();
129            
130                    FloatMatrix Z = X.transpose().mmul(Y);
131                    Z.muli(-2.0f); //Z.print();
132                    Z.addiColumnVector(XX);
133                    Z.addiRowVector(YY);
134            
135                    return Z;
136            }
137    
138            /** Center a vector (subtract mean from all elements (in-place). */
139            public static FloatMatrix center(FloatMatrix x) {
140                    return x.subi(x.mean());
141            }
142            
143            /** Center the rows of a matrix (in-place). */
144            public static FloatMatrix centerRows(FloatMatrix x) {
145                    FloatMatrix temp = new FloatMatrix(x.columns);
146                    for (int r = 0; r < x.rows; r++)
147                            x.putRow(r, center(x.getRow(r, temp)));
148                    return x;
149            }
150            
151            /** Center the columns of a matrix (in-place). */
152            public static FloatMatrix centerColumns(FloatMatrix x) {
153                    FloatMatrix temp = new FloatMatrix(x.rows);
154                    for (int c = 0; c < x.columns; c++)
155                            x.putColumn(c, center(x.getColumn(c, temp)));
156                    return x;
157            }
158            
159            /** Normalize a vector (scale such that its Euclidean norm is 1) (in-place). */
160            public static FloatMatrix normalize(FloatMatrix x) {
161                    return x.divi(x.norm2());
162            }
163    
164            /** Normalize the rows of a matrix (in-place). */
165            public static FloatMatrix normalizeRows(FloatMatrix x) {
166                    FloatMatrix temp = new FloatMatrix(x.columns);
167                    for (int r = 0; r < x.rows; r++)
168                            x.putRow(r, normalize(x.getRow(r, temp)));
169                    return x;
170            }
171            
172            /** Normalize the columns of a matrix (in-place). */
173            public static FloatMatrix normalizeColumns(FloatMatrix x) {
174                    FloatMatrix temp = new FloatMatrix(x.rows);
175                    for (int c = 0; c < x.columns; c++)
176                            x.putColumn(c, normalize(x.getColumn(c, temp)));
177                    return x;
178            }
179    
180    //END
181    }