001    /*
002     * $HeadURL: http://juliusdavies.ca/svn/not-yet-commons-ssl/tags/commons-ssl-0.3.11/src/java/org/apache/commons/ssl/RMISocketFactoryImpl.java $
003     * $Revision: 144 $
004     * $Date: 2009-05-25 11:14:29 -0700 (Mon, 25 May 2009) $
005     *
006     * ====================================================================
007     * Licensed to the Apache Software Foundation (ASF) under one
008     * or more contributor license agreements.  See the NOTICE file
009     * distributed with this work for additional information
010     * regarding copyright ownership.  The ASF licenses this file
011     * to you under the Apache License, Version 2.0 (the
012     * "License"); you may not use this file except in compliance
013     * with the License.  You may obtain a copy of the License at
014     *
015     *   http://www.apache.org/licenses/LICENSE-2.0
016     *
017     * Unless required by applicable law or agreed to in writing,
018     * software distributed under the License is distributed on an
019     * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
020     * KIND, either express or implied.  See the License for the
021     * specific language governing permissions and limitations
022     * under the License.
023     * ====================================================================
024     *
025     * This software consists of voluntary contributions made by many
026     * individuals on behalf of the Apache Software Foundation.  For more
027     * information on the Apache Software Foundation, please see
028     * <http://www.apache.org/>.
029     *
030     */
031    
032    package org.apache.commons.ssl;
033    
import javax.net.ServerSocketFactory;
import javax.net.SocketFactory;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLProtocolException;
import javax.net.ssl.SSLSocket;
import java.io.EOFException;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.rmi.server.RMISocketFactory;
import java.security.GeneralSecurityException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
064    
065    
066    /**
067     * An RMISocketFactory ideal for using RMI over SSL.  The server secures both
068     * the registry and the remote objects.  The client assumes that either both
069     * the registry and the remote objects will use SSL, or both will use
070     * plain-socket.  The client is able to auto detect plain-socket registries
 * and downgrades itself to accommodate those.
072     * <p/>
073     * Unlike most existing RMI over SSL solutions in use (including Java 5's
074     * javax.rmi.ssl.SslRMIClientSocketFactory), this one does proper SSL hostname
 * verification.  From the client perspective this is straightforward.  From
076     * the server perspective we introduce a clever trick:  we perform an initial
077     * "hostname verification" by trying the current value of
078     * "java.rmi.server.hostname" against our server certificate.  If the
079     * "java.rmi.server.hostname" System Property isn't set, we set it ourselves
080     * using the CN value we extract from our server certificate!  (Some
081     * complications arise should a wildcard certificate show up, but we try our
082     * best to deal with those).
083     * <p/>
084     * An SSL server cannot be started without a private key.  We have defined some
085     * default behaviour for trying to find a private key to use that we believe
086     * is convenient and sensible:
087     * <p/>
088     * If running from inside Tomcat, we try to re-use Tomcat's private key and
089     * certificate chain (assuming Tomcat-SSL on port 8443 is enabled).  If this
090     * isn't available, we look for the "javax.net.ssl.keyStore" System property.
091     * Finally, if that isn't available, we look for "~/.keystore" and assume
092     * a password of "changeit".
093     * <p/>
 * If, after all these attempts, we still fail to find a private key, the
095     * RMISocketFactoryImpl() constructor will throw an SSLException.
096     *
097     * @author Credit Union Central of British Columbia
098     * @author <a href="http://www.cucbc.com/">www.cucbc.com</a>
099     * @author <a href="mailto:juliusdavies@cucbc.com">juliusdavies@cucbc.com</a>
100     * @since 22-Apr-2005
101     */
102    public class RMISocketFactoryImpl extends RMISocketFactory {
103        public final static String RMI_HOSTNAME_KEY = "java.rmi.server.hostname";
104        private final static LogWrapper log = LogWrapper.getLogger(RMISocketFactoryImpl.class);
105    
106        private volatile SocketFactory defaultClient;
107        private volatile ServerSocketFactory sslServer;
108        private volatile String localBindAddress = null;
109        private volatile int anonymousPort = 31099;
110        private Map clientMap = new TreeMap();
111        private Map serverSockets = new HashMap();
112        private final SocketFactory plainClient = SocketFactory.getDefault();
113    
114        public RMISocketFactoryImpl() throws GeneralSecurityException, IOException {
115            this(true);
116        }
117    
118        /**
119         * @param createDefaultServer If false, then we only set the default
120         *                            client, and the default server is set to null.
121         *                            If true, then a default server is also created.
122         * @throws GeneralSecurityException bad things
123         * @throws IOException              bad things
124         */
125        public RMISocketFactoryImpl(boolean createDefaultServer)
126            throws GeneralSecurityException, IOException {
127            SSLServer defaultServer = createDefaultServer ? new SSLServer() : null;
128            SSLClient defaultClient = new SSLClient();
129    
130            // RMI calls to localhost will not check that host matches CN in
131            // certificate.  Hopefully this is acceptable.  (The registry server
132            // will followup the registry lookup with the proper DNS name to get
133            // the remote object, anyway).
134            HostnameVerifier verifier = HostnameVerifier.DEFAULT_AND_LOCALHOST;
135            defaultClient.setHostnameVerifier(verifier);
136            if (defaultServer != null) {
137                defaultServer.setHostnameVerifier(verifier);
138                // The RMI server will try to re-use Tomcat's "port 8443" SSL
139                // Certificate if possible.
140                defaultServer.useTomcatSSLMaterial();
141                X509Certificate[] x509 = defaultServer.getAssociatedCertificateChain();
142                if (x509 == null || x509.length < 1) {
143                    throw new SSLException("Cannot initialize RMI-SSL Server: no KeyMaterial!");
144                }
145                setServer(defaultServer);
146            }
147            setDefaultClient(defaultClient);
148        }
149    
150        public void setServer(ServerSocketFactory f)
151            throws GeneralSecurityException, IOException {
152            this.sslServer = f;
153            if (f instanceof SSLServer) {
154                final HostnameVerifier VERIFIER;
155                VERIFIER = HostnameVerifier.DEFAULT_AND_LOCALHOST;
156    
157                final SSLServer ssl = (SSLServer) f;
158                final X509Certificate[] chain = ssl.getAssociatedCertificateChain();
159                String[] cns = Certificates.getCNs(chain[0]);
160                String[] subjectAlts = Certificates.getDNSSubjectAlts(chain[0]);
161                LinkedList names = new LinkedList();
162                if (cns != null && cns.length > 0) {
163                    // Only first CN is used.  Not going to get into the IE6 nonsense
164                    // where all CN values are used.
165                    names.add(cns[0]);
166                }
167                if (subjectAlts != null && subjectAlts.length > 0) {
168                    names.addAll(Arrays.asList(subjectAlts));
169                }
170    
171                String rmiHostName = System.getProperty(RMI_HOSTNAME_KEY);
172                // If "java.rmi.server.hostname" is already set, don't mess with it.
173                // But blowup if it's not going to work with our SSL Server
174                // Certificate!
175                if (rmiHostName != null) {
176                    try {
177                        VERIFIER.check(rmiHostName, cns, subjectAlts);
178                    }
179                    catch (SSLException ssle) {
180                        String s = ssle.toString();
181                        throw new SSLException(RMI_HOSTNAME_KEY + " of " + rmiHostName + " conflicts with SSL Server Certificate: " + s);
182                    }
183                } else {
184                    // If SSL Cert only contains one non-wild name, just use that and
185                    // hope for the best.
186                    boolean hopingForBest = false;
187                    if (names.size() == 1) {
188                        String name = (String) names.get(0);
189                        if (!name.startsWith("*")) {
190                            System.setProperty(RMI_HOSTNAME_KEY, name);
191                            log.warn("commons-ssl '" + RMI_HOSTNAME_KEY + "' set to '" + name + "' as found in my SSL Server Certificate.");
192                            hopingForBest = true;
193                        }
194                    }
195                    if (!hopingForBest) {
196                        // Help me, Obi-Wan Kenobi; you're my only hope.  All we can
197                        // do now is grab our internet-facing addresses, reverse-lookup
198                        // on them, and hope that one of them validates against our
199                        // server cert.
200                        Set s = getMyInternetFacingIPs();
201                        Iterator it = s.iterator();
202                        while (it.hasNext()) {
203                            String name = (String) it.next();
204                            try {
205                                VERIFIER.check(name, cns, subjectAlts);
206                                System.setProperty(RMI_HOSTNAME_KEY, name);
207                                log.warn("commons-ssl '" + RMI_HOSTNAME_KEY + "' set to '" + name + "' as found by reverse-dns against my own IP.");
208                                hopingForBest = true;
209                                break;
210                            }
211                            catch (SSLException ssle) {
212                                // next!
213                            }
214                        }
215                    }
216                    if (!hopingForBest) {
217                        throw new SSLException("'" + RMI_HOSTNAME_KEY + "' not present.  Must work with my SSL Server Certificate's CN field: " + names);
218                    }
219                }
220            }
221            trustOurself();
222        }
223    
224        public void setLocalBindAddress(String localBindAddress) {
225            this.localBindAddress = localBindAddress;
226        }
227    
228        public void setAnonymousPort(int port) {
229            this.anonymousPort = port;
230        }
231    
232        public void setDefaultClient(SocketFactory f)
233            throws GeneralSecurityException, IOException {
234            this.defaultClient = f;
235            trustOurself();
236        }
237    
238        public void setClient(String host, SocketFactory f)
239            throws GeneralSecurityException, IOException {
240            if (f != null && sslServer != null) {
241                boolean clientIsCommonsSSL = f instanceof SSLClient;
242                boolean serverIsCommonsSSL = sslServer instanceof SSLServer;
243                if (clientIsCommonsSSL && serverIsCommonsSSL) {
244                    SSLClient c = (SSLClient) f;
245                    SSLServer s = (SSLServer) sslServer;
246                    trustEachOther(c, s);
247                }
248            }
249            Set names = hostnamePossibilities(host);
250            Iterator it = names.iterator();
251            synchronized (this) {
252                while (it.hasNext()) {
253                    clientMap.put(it.next(), f);
254                }
255            }
256        }
257    
258        public void removeClient(String host) {
259            Set names = hostnamePossibilities(host);
260            Iterator it = names.iterator();
261            synchronized (this) {
262                while (it.hasNext()) {
263                    clientMap.remove(it.next());
264                }
265            }
266        }
267    
268        public synchronized void removeClient(SocketFactory sf) {
269            Iterator it = clientMap.entrySet().iterator();
270            while (it.hasNext()) {
271                Map.Entry entry = (Map.Entry) it.next();
272                Object o = entry.getValue();
273                if (sf.equals(o)) {
274                    it.remove();
275                }
276            }
277        }
278    
279        private Set hostnamePossibilities(String host) {
280            host = host != null ? host.toLowerCase().trim() : "";
281            if ("".equals(host)) {
282                return Collections.EMPTY_SET;
283            }
284            TreeSet names = new TreeSet();
285            names.add(host);
286            InetAddress[] addresses;
287            try {
288                // If they gave us "hostname.com", this will give us the various
289                // IP addresses:
290                addresses = InetAddress.getAllByName(host);
291                for (int i = 0; i < addresses.length; i++) {
292                    String name1 = addresses[i].getHostName();
293                    String name2 = addresses[i].getHostAddress();
294                    names.add(name1.trim().toLowerCase());
295                    names.add(name2.trim().toLowerCase());
296                }
297            }
298            catch (UnknownHostException uhe) {
299                /* oh well, nothing found, nothing to add for this client */
300            }
301    
302            try {
303                host = InetAddress.getByName(host).getHostAddress();
304    
305                // If they gave us "1.2.3.4", this will hopefully give us
306                // "hostname.com" so that we can then try and find any other
307                // IP addresses associated with that name.
308                host = InetAddress.getByName(host).getHostName();
309                names.add(host.trim().toLowerCase());
310                addresses = InetAddress.getAllByName(host);
311                for (int i = 0; i < addresses.length; i++) {
312                    String name1 = addresses[i].getHostName();
313                    String name2 = addresses[i].getHostAddress();
314                    names.add(name1.trim().toLowerCase());
315                    names.add(name2.trim().toLowerCase());
316                }
317            }
318            catch (UnknownHostException uhe) {
319                /* oh well, nothing found, nothing to add for this client */
320            }
321            return names;
322        }
323    
324        private void trustOurself()
325            throws GeneralSecurityException, IOException {
326            if (defaultClient == null || sslServer == null) {
327                return;
328            }
329            boolean clientIsCommonsSSL = defaultClient instanceof SSLClient;
330            boolean serverIsCommonsSSL = sslServer instanceof SSLServer;
331            if (clientIsCommonsSSL && serverIsCommonsSSL) {
332                SSLClient c = (SSLClient) defaultClient;
333                SSLServer s = (SSLServer) sslServer;
334                trustEachOther(c, s);
335            }
336        }
337    
338        private void trustEachOther(SSLClient client, SSLServer server)
339            throws GeneralSecurityException, IOException {
340            if (client != null && server != null) {
341                // Our own client should trust our own server.
342                X509Certificate[] certs = server.getAssociatedCertificateChain();
343                if (certs != null && certs[0] != null) {
344                    TrustMaterial tm = new TrustMaterial(certs[0]);
345                    client.addTrustMaterial(tm);
346                }
347    
348                // Our own server should trust our own client.
349                certs = client.getAssociatedCertificateChain();
350                if (certs != null && certs[0] != null) {
351                    TrustMaterial tm = new TrustMaterial(certs[0]);
352                    server.addTrustMaterial(tm);
353                }
354            }
355        }
356    
357        public ServerSocketFactory getServer() { return sslServer; }
358    
359        public SocketFactory getDefaultClient() { return defaultClient; }
360    
361        public synchronized SocketFactory getClient(String host) {
362            host = host != null ? host.trim().toLowerCase() : "";
363            return (SocketFactory) clientMap.get(host);
364        }
365    
366        public synchronized ServerSocket createServerSocket(int port)
367            throws IOException {
368            // Re-use existing ServerSocket if possible.
369            if (port == 0) {
370                port = anonymousPort;
371            }
372            Integer key = new Integer(port);
373            ServerSocket ss = (ServerSocket) serverSockets.get(key);
374            if (ss == null || ss.isClosed()) {
375                if (ss != null && ss.isClosed()) {
376                    System.out.println("found closed server on port: " + port);
377                }
378                log.debug("commons-ssl RMI server-socket: listening on port " + port);
379                ss = sslServer.createServerSocket(port);
380                serverSockets.put(key, ss);
381            }
382            return ss;
383        }
384    
385        public Socket createSocket(String host, int port)
386            throws IOException {
387            host = host != null ? host.trim().toLowerCase() : "";
388            InetAddress local = null;
389            String bindAddress = localBindAddress;
390            if (bindAddress == null) {
391                bindAddress = System.getProperty(RMI_HOSTNAME_KEY);
392                if (bindAddress != null) {
393                    local = InetAddress.getByName(bindAddress);
394                    if (!local.isLoopbackAddress()) {
395                        String ip = local.getHostAddress();
396                        Set myInternetIps = getMyInternetFacingIPs();
397                        if (!myInternetIps.contains(ip)) {
398                            log.warn("Cannot bind to " + ip + " since it doesn't exist on this machine.");
399                            // Not going to be able to bind as this.  Our RMI_HOSTNAME_KEY
400                            // must be set to some kind of proxy in front of us.  So we
401                            // still want to use it, but we can't bind to it.
402                            local = null;
403                            bindAddress = null;
404                        }
405                    }
406                }
407            }
408            if (bindAddress == null) {
409                // Our last resort - let's make sure we at least use something that's
410                // internet facing!
411                bindAddress = getMyDefaultIP();
412            }
413            if (local == null && bindAddress != null) {
414                local = InetAddress.getByName(bindAddress);
415                localBindAddress = local.getHostName();
416            }
417    
418            SocketFactory sf;
419            synchronized (this) {
420                sf = (SocketFactory) clientMap.get(host);
421            }
422            if (sf == null) {
423                sf = defaultClient;
424            }
425    
426            Socket s = null;
427            SSLSocket ssl = null;
428            int soTimeout = Integer.MIN_VALUE;
429            IOException reasonForPlainSocket = null;
430            boolean tryPlain = false;
431            try {
432                s = sf.createSocket(host, port, local, 0);
433                soTimeout = s.getSoTimeout();
434                if (!(s instanceof SSLSocket)) {
435                    // Someone called setClient() or setDefaultClient() and passed in
436                    // a plain socket factory.  Okay, nothing to see, move along.
437                    return s;
438                } else {
439                    ssl = (SSLSocket) s;
440                }
441    
442                // If we don't get the peer certs in 15 seconds, revert to plain
443                // socket.
444                ssl.setSoTimeout(15000);
445                ssl.getSession().getPeerCertificates();
446    
447                // Everything worked out okay, so go back to original soTimeout.
448                ssl.setSoTimeout(soTimeout);
449                return ssl;
450            }
451            catch (IOException ioe) {
452                // SSL didn't work.  Let's analyze the IOException to see if maybe
453                // we're accidentally attempting to talk to a plain-socket RMI
454                // server.
455                Throwable t = ioe;
456                while (!tryPlain && t != null) {
457                    tryPlain = tryPlain || t instanceof EOFException;
458                    tryPlain = tryPlain || t instanceof InterruptedIOException;
459                    tryPlain = tryPlain || t instanceof SSLProtocolException;
460                    t = t.getCause();
461                }
462                if (!tryPlain && ioe instanceof SSLPeerUnverifiedException) {
463                    try {
464                        if (ssl != null) {
465                            ssl.startHandshake();
466                        }
467                    }
468                    catch (IOException ioe2) {
469                        // Stacktrace from startHandshake() will be more descriptive
470                        // then the one we got from getPeerCertificates().
471                        ioe = ioe2;
472                        t = ioe2;
473                        while (!tryPlain && t != null) {
474                            tryPlain = tryPlain || t instanceof EOFException;
475                            tryPlain = tryPlain || t instanceof InterruptedIOException;
476                            tryPlain = tryPlain || t instanceof SSLProtocolException;
477                            t = t.getCause();
478                        }
479                    }
480                }
481                if (!tryPlain) {
482                    log.debug("commons-ssl RMI-SSL failed: " + ioe);
483                    throw ioe;
484                } else {
485                    reasonForPlainSocket = ioe;
486                }
487            }
488            finally {
489                // Some debug logging:
490                boolean isPlain = tryPlain || (s != null && ssl == null);
491                String socket = isPlain ? "RMI plain-socket " : "RMI ssl-socket ";
492                String localIP = local != null ? local.getHostAddress() : "ANY";
493                StringBuffer buf = new StringBuffer(64);
494                buf.append(socket);
495                buf.append(localIP);
496                buf.append(" --> ");
497                buf.append(host);
498                buf.append(":");
499                buf.append(port);
500                log.debug(buf.toString());
501            }
502    
503            // SSL didn't work.  Remote server either timed out, or sent EOF, or
504            // there was some kind of SSLProtocolException.  (Any other problem
505            // would have caused an IOException to be thrown, so execution wouldn't
506            // have made it this far).  Maybe plain socket will work in these three
507            // cases.
508            sf = plainClient;
509            s = JavaImpl.connect(null, sf, host, port, local, 0, 15000, null);
510            if (soTimeout != Integer.MIN_VALUE) {
511                s.setSoTimeout(soTimeout);
512            }
513    
514            try {
515                // Plain socket worked!  Let's remember that for next time an RMI call
516                // against this host happens.
517                setClient(host, plainClient);
518                String msg = "RMI downgrading from SSL to plain-socket for " + host + " because of " + reasonForPlainSocket;
519                log.warn(msg, reasonForPlainSocket);
520            }
521            catch (GeneralSecurityException gse) {
522                throw new RuntimeException("can't happen because we're using plain socket", gse);
523                // won't happen because we're using plain socket, not SSL.
524            }
525    
526            return s;
527        }
528    
529    
530        public static String getMyDefaultIP() {
531            String anInternetIP = "64.111.122.211";
532            String ip = null;
533            try {
534                DatagramSocket dg = new DatagramSocket();
535                dg.setSoTimeout(250);
536                // 64.111.122.211 is juliusdavies.ca.
537                // This code doesn't actually send any packets (so no firewalls can
538                // get in the way).  It's just a neat trick for getting our
539                // internet-facing interface card.
540                InetAddress addr = InetAddress.getByName(anInternetIP);
541                dg.connect(addr, 12345);
542                InetAddress localAddr = dg.getLocalAddress();
543                ip = localAddr.getHostAddress();
544                // log.debug( "Using bogus UDP socket (" + anInternetIP + ":12345), I think my IP address is: " + ip );
545                dg.close();
546                if (localAddr.isLoopbackAddress() || "0.0.0.0".equals(ip)) {
547                    ip = null;
548                }
549            }
550            catch (IOException ioe) {
551                log.debug("Bogus UDP didn't work: " + ioe);
552            }
553            return ip;
554        }
555    
556        public static SortedSet getMyInternetFacingIPs() throws SocketException {
557            TreeSet set = new TreeSet();
558            Enumeration en = NetworkInterface.getNetworkInterfaces();
559            while (en.hasMoreElements()) {
560                NetworkInterface ni = (NetworkInterface) en.nextElement();
561                Enumeration en2 = ni.getInetAddresses();
562                while (en2.hasMoreElements()) {
563                    InetAddress addr = (InetAddress) en2.nextElement();
564                    if (!addr.isLoopbackAddress()) {
565                        String ip = addr.getHostAddress();
566                        String reverse = addr.getHostName();
567                        // IP:
568                        set.add(ip);
569                        // Reverse-Lookup:
570                        set.add(reverse);
571    
572                    }
573                }
574            }
575            return set;
576        }
577    
578    }