Nico Golde:
[apps/madmutt.git] / mutt_ssl.c
index 1f847a4..a56147d 100644 (file)
@@ -16,8 +16,9 @@
  *     Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111, USA.
  */
 
-/* for SSL NO_* defines */
-#include "config.h"
+#if HAVE_CONFIG_H
+# include "config.h"
+#endif
 
 #include <openssl/ssl.h>
 #include <openssl/x509.h>
@@ -68,13 +69,15 @@ sslsockdata;
 /* local prototypes */
 int ssl_init (void);
 static int add_entropy (const char *file);
-static int ssl_check_certificate (sslsockdata * data);
 static int ssl_socket_read (CONNECTION* conn, char* buf, size_t len);
 static int ssl_socket_write (CONNECTION* conn, const char* buf, size_t len);
 static int ssl_socket_open (CONNECTION * conn);
 static int ssl_socket_close (CONNECTION * conn);
 static int tls_close (CONNECTION* conn);
-int ssl_negotiate (sslsockdata*);
+static int ssl_check_certificate (sslsockdata * data);
+static void ssl_get_client_cert(sslsockdata *ssldata, CONNECTION *conn);
+static int ssl_passwd_cb(char *buf, int size, int rwflag, void *userdata);
+static int ssl_negotiate (sslsockdata*);
 
 /* mutt_ssl_starttls: Negotiate TLS over an already opened connection.
  *   TODO: Merge this code better with ssl_socket_open. */
@@ -94,6 +97,8 @@ int mutt_ssl_starttls (CONNECTION* conn)
     goto bail_ssldata;
   }
 
+  ssl_get_client_cert(ssldata, conn);
+
   if (! (ssldata->ssl = SSL_new (ssldata->ctx)))
   {
     dprint (1, (debugfile, "mutt_ssl_starttls: Error allocating SSL\n"));
@@ -111,9 +116,9 @@ int mutt_ssl_starttls (CONNECTION* conn)
 
   /* hmm. watch out if we're starting TLS over any method other than raw. */
   conn->sockdata = ssldata;
-  conn->read = ssl_socket_read;
-  conn->write = ssl_socket_write;
-  conn->close = tls_close;
+  conn->conn_read = ssl_socket_read;
+  conn->conn_write = ssl_socket_write;
+  conn->conn_close = tls_close;
 
   conn->ssf = SSL_CIPHER_get_bits (SSL_get_current_cipher (ssldata->ssl),
     &maxbits);
@@ -228,14 +233,14 @@ int ssl_socket_setup (CONNECTION * conn)
 {
   if (ssl_init() < 0)
   {
-    conn->open = ssl_socket_open_err;
+    conn->conn_open = ssl_socket_open_err;
     return -1;
   }
 
-  conn->open   = ssl_socket_open;
-  conn->read   = ssl_socket_read;
-  conn->write  = ssl_socket_write;
-  conn->close  = ssl_socket_close;
+  conn->conn_open      = ssl_socket_open;
+  conn->conn_read      = ssl_socket_read;
+  conn->conn_write     = ssl_socket_write;
+  conn->conn_close     = ssl_socket_close;
 
   return 0;
 }
@@ -279,6 +284,8 @@ static int ssl_socket_open (CONNECTION * conn)
     SSL_CTX_set_options(data->ctx, SSL_OP_NO_SSLv3);
   }
 
+  ssl_get_client_cert(data, conn);
+
   data->ssl = SSL_new (data->ctx);
   SSL_set_fd (data->ssl, conn->fd);
 
@@ -296,7 +303,7 @@ static int ssl_socket_open (CONNECTION * conn)
 
 /* ssl_negotiate: After SSL state has been initialised, attempt to negotiate
  *   SSL over the wire, including certificate checks. */
-int ssl_negotiate (sslsockdata* ssldata)
+static int ssl_negotiate (sslsockdata* ssldata)
 {
   int err;
   const char* errmsg;
@@ -315,7 +322,7 @@ int ssl_negotiate (sslsockdata* ssldata)
       errmsg = _("I/O error");
       break;
     case SSL_ERROR_SSL:
-      errmsg = _("unspecified protocol error");
+      errmsg = ERR_error_string (ERR_get_error (), NULL);
       break;
     default:
       errmsg = _("unknown error");
@@ -366,9 +373,9 @@ static int tls_close (CONNECTION* conn)
   int rc;
 
   rc = ssl_socket_close (conn);
-  conn->read = raw_socket_read;
-  conn->write = raw_socket_write;
-  conn->close = raw_socket_close;
+  conn->conn_read = raw_socket_read;
+  conn->conn_write = raw_socket_write;
+  conn->conn_close = raw_socket_close;
 
   return rc;
 }
@@ -403,7 +410,7 @@ static void x509_fingerprint (char *s, int l, X509 * cert)
 
   if (!X509_digest (cert, EVP_md5 (), md, &n))
   {
-    snprintf (s, l, _("[unable to calculate]"));
+    snprintf (s, l, "%s", _("[unable to calculate]"));
   }
   else
   {
@@ -411,7 +418,7 @@ static void x509_fingerprint (char *s, int l, X509 * cert)
     {
       char ch[8];
       snprintf (ch, 8, "%02X%s", md[j], (j % 2 ? " " : ""));
-      strncat (s, ch, l);
+      safe_strcat (s, l, ch);
     }
   }
 }
@@ -598,7 +605,7 @@ static int ssl_check_certificate (sslsockdata * data)
   }
 
   row++;
-  snprintf (menu->dialog[row++], SHORT_STRING, _("This certificate is valid"));
+  snprintf (menu->dialog[row++], SHORT_STRING, "%s", _("This certificate is valid"));
   snprintf (menu->dialog[row++], SHORT_STRING, _("   from %s"), 
       asn1time_to_string (X509_get_notBefore (data->cert)));
   snprintf (menu->dialog[row++], SHORT_STRING, _("     to %s"), 
@@ -623,9 +630,9 @@ static int ssl_check_certificate (sslsockdata * data)
   
   helpstr[0] = '\0';
   mutt_make_help (buf, sizeof (buf), _("Exit  "), MENU_GENERIC, OP_EXIT);
-  strncat (helpstr, buf, sizeof (helpstr));
+  safe_strcat (helpstr, sizeof (helpstr), buf);
   mutt_make_help (buf, sizeof (buf), _("Help"), MENU_GENERIC, OP_HELP);
-  strncat (helpstr, buf, sizeof (helpstr));
+  safe_strcat (helpstr, sizeof (helpstr), buf);
   menu->help = helpstr;
 
   done = 0;
@@ -667,3 +674,31 @@ static int ssl_check_certificate (sslsockdata * data)
   mutt_menuDestroy (&menu);
   return (done == 2);
 }
+
+static void ssl_get_client_cert(sslsockdata *ssldata, CONNECTION *conn)
+{
+  if (SslClientCert)
+  {
+    dprint (2, (debugfile, "Using client certificate %s\n", SslClientCert));
+    SSL_CTX_set_default_passwd_cb_userdata(ssldata->ctx, &conn->account);
+    SSL_CTX_set_default_passwd_cb(ssldata->ctx, ssl_passwd_cb);
+    SSL_CTX_use_certificate_file(ssldata->ctx, SslClientCert, SSL_FILETYPE_PEM);
+    SSL_CTX_use_PrivateKey_file(ssldata->ctx, SslClientCert, SSL_FILETYPE_PEM);
+  }
+}
+
+static int ssl_passwd_cb(char *buf, int size, int rwflag, void *userdata)
+{
+  ACCOUNT *account = (ACCOUNT*)userdata;
+
+  if (mutt_account_getuser (account))
+    return 0;
+
+  dprint (2, (debugfile, "ssl_passwd_cb: getting password for %s@%s:%u\n",
+             account->user, account->host, account->port));
+  
+  if (mutt_account_getpass (account))
+    return 0;
+
+  return snprintf(buf, size, "%s", account->pass);
+}