Python 3: How to log SSL handshake errors from server side

I'm using HTTPServer for a basic HTTP server using SSL. I would like to log any time a client initiates an SSL Handshake (or perhaps any time a socket is accepted?) along with any associated errors. I imagine that I'd need to extend some class or override some method, but I'm not sure which or how to properly go about implementing it. I'd greatly appreciate any help. Thanks in advance!
Trimmed down sample code:
from http.server import BaseHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn
from threading import Thread
import ssl
import logging
import sys

class MyHTTPHandler(BaseHTTPRequestHandler):
    def log_message(self, format, *args):
        logger.info("%s - - %s" % (self.address_string(), format % args))

    def do_GET(self):
        self.send_response(200)
        self.end_headers()
        self.wfile.write('test'.encode("utf-8"))

class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    pass

logger = logging.getLogger('myserver')
handler = logging.FileHandler('server.log')
formatter = logging.Formatter('[%(asctime)s] %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)

server = ThreadedHTTPServer(('', 443), MyHTTPHandler)
server.socket = ssl.wrap_socket(server.socket, keyfile='server.key', certfile='server.crt',
                                server_side=True, cert_reqs=ssl.CERT_REQUIRED,
                                ca_certs='client.crt')
Thread(target=server.serve_forever).start()

try:
    quitcheck = input("Type 'quit' at any time to quit.\n")
    if quitcheck == "quit":
        server.shutdown()
except KeyboardInterrupt:
    server.shutdown()

From looking at the ssl module, most of the relevant magic happens in the SSLSocket class.
ssl.wrap_socket() is just a tiny convenience function that basically serves as a factory for an SSLSocket with some reasonable defaults, and wraps an existing socket.
Unfortunately, SSLSocket does not seem to do any logging of its own, so there's no easy way to turn up a logging level, set a debug flag or register any handlers.
So what you can do instead is to subclass SSLSocket, override the methods you're interested in with your own that do some logging, and create and use your own wrap_socket helper function.
Subclassing SSLSocket
First, copy over ssl.wrap_socket() from your Python installation's ssl.py (e.g. .../lib/python3.x/ssl.py) into your code. (Make sure that any code you copy and modify actually comes from the Python installation you're using - the code may have changed between different Python versions.)
Now adapt your copy of wrap_socket() to:
- create an instance of a LoggingSSLSocket (which we'll implement below) instead of SSLSocket,
- and use constants from the ssl module where necessary (ssl.CERT_NONE and ssl.PROTOCOL_SSLv23 in this example):
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=ssl.CERT_NONE,
                ssl_version=ssl.PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    return LoggingSSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
                            server_side=server_side, cert_reqs=cert_reqs,
                            ssl_version=ssl_version, ca_certs=ca_certs,
                            do_handshake_on_connect=do_handshake_on_connect,
                            suppress_ragged_eofs=suppress_ragged_eofs,
                            ciphers=ciphers)
Now change your line
server.socket = ssl.wrap_socket(server.socket, ...)
to
server.socket = wrap_socket(server.socket, ...)
in order to use your own wrap_socket().
Now for subclassing SSLSocket. Create a class LoggingSSLSocket that subclasses SSLSocket by adding the following to your code:
class LoggingSSLSocket(ssl.SSLSocket):
    def accept(self, *args, **kwargs):
        logger.debug('Accepting connection...')
        result = super(LoggingSSLSocket, self).accept(*args, **kwargs)
        logger.debug('Done accepting connection.')
        return result

    def do_handshake(self, *args, **kwargs):
        logger.debug('Starting handshake...')
        result = super(LoggingSSLSocket, self).do_handshake(*args, **kwargs)
        logger.debug('Done with handshake.')
        return result
Here we override the accept() and do_handshake() methods of ssl.SSLSocket - everything else stays the same, since the class inherits from SSLSocket.
Generic approach to overriding methods
I used a particular pattern for overriding these methods in order to make it easier to apply to pretty much any method you'll ever override:
def methodname(self, *args, **kwargs):
*args, **kwargs makes sure our method accepts any number of positional and keyword arguments, if any. accept doesn't actually take any of those, but it still works because of Python's packing / unpacking of argument lists.
    logger.debug('Before call to superclass method')
Here you get the opportunity to do your own thing before calling the superclass' method.
    result = super(LoggingSSLSocket, self).methodname(*args, **kwargs)
This is the actual call to the superclass' method. See the docs on super() for details on how this works, but it basically calls .methodname() on LoggingSSLSocket's superclass (SSLSocket). Because we pass *args, **kwargs to the method, we just pass on any positional and keyword arguments our method got - we don't even need to know what they are, the method signatures will always match.
Because some methods (like accept()) will return a result, we store that result and return it at the end of our method, just before doing our post-call work:
    logger.debug('After call.')
    return result
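Assembled, the pattern is not specific to SSLSocket at all. Here's a minimal sketch applying it to a hypothetical LoggingDict (the class and method are made up purely for illustration):
class LoggingDict(dict):
    def pop(self, *args, **kwargs):
        logger.debug('Before call to superclass method')
        # Forward whatever arguments we received, unchanged.
        result = super(LoggingDict, self).pop(*args, **kwargs)
        logger.debug('After call.')
        return result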
Logging more details
If you want to include more information in your logging statements, you'll likely have to override the respective methods completely. So copy them over from ssl.py, modify them as required, and make sure you add any missing imports.
Here's an example for accept() that includes the IP address and local port of the client that's trying to connect:
def accept(self):
    """Accepts a new connection from a remote client, and returns
    a tuple containing that new connection wrapped with a server-side
    SSL channel, and the address of the remote client."""
    newsock, addr = socket.accept(self)
    logger.debug("Accepting connection from '%s'..." % (addr, ))
    newsock = self.context.wrap_socket(newsock,
                                       do_handshake_on_connect=self.do_handshake_on_connect,
                                       suppress_ragged_eofs=self.suppress_ragged_eofs,
                                       server_side=True)
    logger.debug('Done accepting connection.')
    return newsock, addr
(Make sure to include from socket import socket in your imports at the top of your code - refer to the ssl module's imports to determine where to import missing names from if you get a NameError. A good text editor with PyFlakes configured is very helpful in pointing out those missing imports.)
This method will result in logging output like this:
[2014-10-24 22:01:40,299] Accepting connection from '('127.0.0.1', 64152)'...
[2014-10-24 22:01:40,300] Done accepting connection.
[2014-10-24 22:01:40,301] Accepting connection from '('127.0.0.1', 64153)'...
[2014-10-24 22:01:40,302] Done accepting connection.
[2014-10-24 22:01:40,306] Accepting connection from '('127.0.0.1', 64155)'...
[2014-10-24 22:01:40,307] Done accepting connection.
[2014-10-24 22:01:40,308] 127.0.0.1 - - "GET / HTTP/1.1" 200 -
Because it involves quite a few changes scattered all over the place, here's a gist containing all the changes to your example code.
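One last note, since your question specifically mentions logging handshake errors: the overrides above only log the happy path. A small sketch of how you could extend do_handshake() to also log failures before re-raising them (my addition, not part of the gist):
class LoggingSSLSocket(ssl.SSLSocket):
    def do_handshake(self, *args, **kwargs):
        logger.debug('Starting handshake...')
        try:
            result = super(LoggingSSLSocket, self).do_handshake(*args, **kwargs)
        except ssl.SSLError as error:
            # Log handshake failures (bad client cert, protocol mismatch, ...)
            # and re-raise so the server's normal error handling still runs.
            logger.error('Handshake failed: %s' % (error,))
            raise
        logger.debug('Done with handshake.')
        return result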

Related

Write a unit test with pytest to test a socket

I wrote a small chat server that does very basic things, and I would like to write the tests around it. Unfortunately I'm quite lost on this; I could use some help to get on the right track.
I have a class called Server() and it contains a method called bind_socket(). I would like to write a unit test (preferably using pytest) for the following method:
import socket
import threading

class Server(threading.Thread):
    """ Server side class
    Instantiate a server in a thread.
    """
    MAX_WAITING_CONNECTIONS = 10

    def __init__(self, host='localhost', port=10000):
        """ Constructor of the Server class.
        Initialize the instance in a thread.
        Args:
            host (str): Host to which to connect (default=localhost)
            port (int): Port on which to connect (default=10000)
        """
        threading.Thread.__init__(self)
        self.host = host
        self.port = port
        self.connections = []
        self.running = True

    def bind_socket(self, ip=socket.AF_INET, protocol=socket.SOCK_STREAM):
        self.server_socket = socket.socket(ip, protocol)
        self.server_socket.bind((self.host, self.port))
        self.server_socket.listen(self.MAX_WAITING_CONNECTIONS)
        self.connections.append(self.server_socket)
I'm wondering what the best way is to write a test for this method, as it doesn't return anything. Should I mock it and check the number of calls to socket(), bind(), listen() and append(), or is that the wrong way to proceed? I'm quite lost on this; I've tried many things with both pytest and unittest, watched conferences and read articles, and I still don't have anything working.
Some explanation and/or examples would be greatly appreciated.
Thanks a lot
For each line of bind_socket you should ask yourself the questions:
- What if this line didn't exist?
- (For conditionals - I know you don't have any here) What if this condition was the other way around?
- Can this line raise exceptions?
You want your tests to cover all these eventualities.
For example, socket.bind can raise an exception if it's already bound, or socket.listen can raise an exception. Do you close the socket afterwards?
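To make that concrete, here's a minimal sketch of such a test using unittest.mock (my own illustration - it assumes the class above lives in a module named server.py, so adjust the import and patch target to your layout):
import socket
from unittest import mock

from server import Server  # assumption: the class lives in server.py

def test_bind_socket_binds_and_listens():
    srv = Server(host='localhost', port=10000)
    with mock.patch('socket.socket') as mock_socket_class:
        mock_sock = mock_socket_class.return_value
        srv.bind_socket()
    # The socket was created with the expected family and type...
    mock_socket_class.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM)
    # ...bound to the configured address, set listening, and kept around.
    mock_sock.bind.assert_called_once_with(('localhost', 10000))
    mock_sock.listen.assert_called_once_with(Server.MAX_WAITING_CONNECTIONS)
    assert srv.connections == [mock_sock]
You would write a separate test for each failure case too, e.g. using mock_sock.bind.side_effect = socket.error to simulate an already-bound address.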

Memcache client with connection pool for Python?

python-memcached memcache client is written in a way where each thread gets its own connection. This makes python-memcached code simple, which is nice, but presents a problem if your application has hundreds or thousands of threads (or if you run lots of applications), because you will quickly run out of available connections in memcache.
Typically this kind of problem is solved by using a connection pool, and indeed the Java memcache libraries I have seen implement connection pooling. After reading the documentation for various Python memcache libraries, it seems the only one offering a connection pool is pylibmc, but it has two problems for me: it is not pure Python, and it does not seem to have a timeout for reserving a client from the pool. While not being pure Python is perhaps not a deal breaker, not having a timeout certainly is. It is also not clear how those pools would work with, for example, dogpile.cache.
Preferably I would like to find a pure Python memcache client with connection pooling that would work with dogpile.cache, but I am open to other suggestions as well. I'd rather avoid changing the application logic, though (like pushing all memcache operations into fewer background threads).
A coworker came up with an idea that seems to work well enough for our use case, so I'm sharing it here. The basic idea is that you create the number of memcache clients you want to use up front, put them in a queue, and whenever you need a memcache client you pull one from the queue. Because Queue.Queue's get() method has an optional timeout parameter, you can also handle the case where you can't get a client in time. You also need to deal with the use of threading.local in the memcache client.
Here is how it could work in code (note that I haven't actually run this exact version so there might be some issues, but this should give you an idea if the textual description did not make sense to you):
import Queue
import memcache

# See http://stackoverflow.com/questions/9539052/python-dynamically-changing-base-classes-at-runtime-how-to
# Don't inherit client from threading.local so that we can reuse clients in
# different threads
memcache.Client = type('Client', (object,), dict(memcache.Client.__dict__))
# Client.__init__ references local, so need to replace that, too
class Local(object): pass
memcache.local = Local

class PoolClient(object):
    '''Pool of memcache clients that has the same API as memcache.Client'''
    def __init__(self, pool_size, pool_timeout, *args, **kwargs):
        self.pool_timeout = pool_timeout
        self.queue = Queue.Queue()
        for _i in range(pool_size):
            self.queue.put(memcache.Client(*args, **kwargs))

    def __getattr__(self, name):
        return lambda *args, **kw: self._call_client_method(name, *args, **kw)

    def _call_client_method(self, name, *args, **kwargs):
        try:
            client = self.queue.get(timeout=self.pool_timeout)
        except Queue.Empty:
            return
        try:
            return getattr(client, name)(*args, **kwargs)
        finally:
            self.queue.put(client)
Many thanks to @Heikki Toivonen for providing ideas for the problem! However, I'm not sure how exactly to call the get() method in order to use a memcache client in the PoolClient. Calling the get() method directly with an arbitrary name gives an AttributeError or MemcachedKeyNoneError.
By combining @Heikki Toivonen's and pylibmc's solutions to the problem, I came up with the following code and am posting it here for the convenience of future users (I have debugged this code and it should be ready to run):
import Queue, memcache
from contextlib import contextmanager

memcache.Client = type('Client', (object,), dict(memcache.Client.__dict__))
# Client.__init__ references local, so need to replace that, too
class Local(object): pass
memcache.local = Local

class PoolClient(object):
    '''Pool of memcache clients that has the same API as memcache.Client'''
    def __init__(self, pool_size, pool_timeout, *args, **kwargs):
        self.pool_timeout = pool_timeout
        self.queue = Queue.Queue()
        for _i in range(pool_size):
            self.queue.put(memcache.Client(*args, **kwargs))
        print "pool_size:", pool_size, ", Queue_size:", self.queue.qsize()

    @contextmanager
    def reserve(self):
        ''' Reference: http://sendapatch.se/projects/pylibmc/pooling.html#pylibmc.ClientPool'''
        client = self.queue.get(timeout=self.pool_timeout)
        try:
            yield client
        finally:
            self.queue.put(client)
            print "Queue_size:", self.queue.qsize()

# Instantiate an instance of PoolClient
mc_client_pool = PoolClient(5, 0, ['127.0.0.1:11211'])

# Use a memcache client from the pool of memcache clients in your apps
with mc_client_pool.reserve() as mc_client:
    pass  # do your work here
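For example, a hypothetical usage of the reserved client (assuming a memcached instance is actually running on 127.0.0.1:11211 as configured above):
with mc_client_pool.reserve() as mc_client:
    mc_client.set("greeting", "hello")     # borrow a pooled client...
    print mc_client.get("greeting")        # ...and it's returned on exit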

Python SimpleXMLRPCServer: get user IP and simple authentication

I am trying to make a very simple XML-RPC server with Python that provides basic authentication plus the ability to obtain the connected user's IP. Let's take the example provided in http://docs.python.org/library/xmlrpclib.html:
import xmlrpclib
from SimpleXMLRPCServer import SimpleXMLRPCServer

def is_even(n):
    return n % 2 == 0

server = SimpleXMLRPCServer(("localhost", 8000))
server.register_function(is_even, "is_even")
server.serve_forever()
So now, the first idea behind this is to make the user supply credentials and process them before allowing him to use the functions. I need very simple authentication - for example, just a code. Right now what I'm doing is forcing the user to supply this code in the function call and testing it with an if-statement.
The second is to be able to get the user's IP when he calls a function, or to store it after he connects to the server.
Moreover, I already have an Apache Server running and it might be simpler to integrate this into it.
What do you think?
This is a related question that I found helpful:
IP address of client in Python SimpleXMLRPCServer?
What worked for me was to grab the client_address in an overridden finish_request method of the server, stash it in the server itself, and then access this in an overridden server _dispatch routine. You might be able to access the server itself from within the method, too, but I was just trying to add the IP address as an automatic first argument to all my method calls. The reason I used a dict was because I'm also going to add a session token and perhaps other metadata as well.
from xmlrpc.server import DocXMLRPCServer
from socketserver import BaseServer

class NewXMLRPCServer(DocXMLRPCServer):
    def finish_request(self, request, client_address):
        self.client_address = client_address
        BaseServer.finish_request(self, request, client_address)

    def _dispatch(self, method, params):
        metadata = {'client_address': self.client_address[0]}
        newParams = (metadata,) + params
        return DocXMLRPCServer._dispatch(self, method, newParams)
Note this will BREAK introspection functions like system.listMethods(), because they aren't expecting the extra argument. One idea would be to check the method name for "system." and just pass the regular params in that case.
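A minimal sketch of that guard (my own illustration, not tested against every introspection method):
    def _dispatch(self, method, params):
        # Let introspection calls through with their original arguments.
        if method.startswith('system.'):
            return DocXMLRPCServer._dispatch(self, method, params)
        metadata = {'client_address': self.client_address[0]}
        return DocXMLRPCServer._dispatch(self, method, (metadata,) + params)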

Validating client certificates in PyOpenSSL

I'm writing an app that requires a cert to be installed in the client browser. I've found this in the PyOpenSSL docs for the "Context" object but I can't see anything about how the callback is supposed to validate the cert, only that it should, somehow.
set_verify(mode, callback)
Set the verification flags for this Context object to mode and
specify that callback should be used for verification callbacks.
mode should be one of VERIFY_NONE and VERIFY_PEER. If
VERIFY_PEER is used, mode can be OR:ed with
VERIFY_FAIL_IF_NO_PEER_CERT and VERIFY_CLIENT_ONCE to further
control the behaviour. callback should take five arguments: A
Connection object, an X509 object, and three integer variables,
which are in turn potential error number, error depth and return
code. callback should return true if verification passes and
false otherwise.
I'm telling the Context object where my (self signed) keys are (see below) so I guess I don't understand why that's not enough for the library to check if the cert presented by the client is a valid one. What should one do in this callback function?
class SecureAJAXServer(PlainAJAXServer):
    def __init__(self, server_address, HandlerClass):
        BaseServer.__init__(self, server_address, HandlerClass)
        ctx = SSL.Context(SSL.SSLv23_METHOD)
        ctx.use_privatekey_file('keys/server.key')
        ctx.use_certificate_file('keys/server.crt')
        ctx.set_session_id("My_experimental_AJAX_Server")
        ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT | SSL.VERIFY_CLIENT_ONCE, callback_func)
        self.socket = SSL.Connection(ctx, socket.socket(self.address_family, self.socket_type))
        self.server_bind()
        self.server_activate()
Caveat: Coding for fun here, def not a pro so if my Q reveals my total lameness, naivety and/or fundamental lack of understanding when it comes to SSL please don't be too rough!
Thanks :)
Roger
In the OpenSSL documentation for set_verify(), the key that you care about is the return code:
callback should take five arguments: A Connection object, an X509
object, and three integer variables, which are in turn potential error
number, error depth and return code. callback should return true
if verification passes and false otherwise.
There is a full working example that shows more or less what you want to do: When are client certificates verified?
Essentially you can ignore the first 4 arguments and just check the value of the return code in the fifth argument like so:
import socket

from OpenSSL.SSL import Context, Connection, SSLv23_METHOD
from OpenSSL.SSL import VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE

class SecureAJAXServer(BaseServer):
    def verify_callback(self, connection, x509, errnum, errdepth, ok):
        if not ok:
            print "Bad Certs"
        else:
            print "Certs are fine"
        return ok

    def __init__(self, server_address, HandlerClass):
        BaseServer.__init__(self, server_address, HandlerClass)
        ctx = Context(SSLv23_METHOD)
        ctx.use_privatekey_file('keys/server.key')
        ctx.use_certificate_file('keys/server.crt')
        ctx.set_session_id("My_experimental_AJAX_Server")
        ctx.set_verify(VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE, self.verify_callback)
        self.socket = Connection(ctx, socket.socket(self.address_family, self.socket_type))
        self.server_bind()
        self.server_activate()
Note: I made one other change, which is to import names directly via from OpenSSL.SSL import ... to simplify your code a bit while I was testing it, so you don't have the SSL. prefix in front of every imported symbol. (The callback also needs to be referenced as self.verify_callback so it resolves inside __init__.)
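One more note, an assumption on my part rather than something from the example above: with a self-signed setup, the ok flag will only ever be true if the context actually trusts the client's certificate, so you would typically also load it (or the CA that signed it) as a trust anchor:
# Assumption: the self-signed client cert doubles as its own CA file.
ctx.load_verify_locations('keys/client.crt')
Without a trust anchor, OpenSSL cannot build a chain for the peer certificate and verification will always fail.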

Best way to run remote commands thru ssh in Twisted?

I have a twisted application which now needs to monitor processes running on several boxes. The way I do it manually is 'ssh and ps'; now I'd like my twisted application to do the same. I have 2 options.
Use paramiko or leverage the power of twisted.conch
I really want to use twisted.conch, but my research led me to believe that it's primarily intended to create SSH servers and SSH clients. However, my requirement is a simple remoteExecute(some_cmd).
I was able to figure out how to do this using paramiko, but I didn't want to stick paramiko in my twisted app before looking at how to do this using twisted.conch.
Code snippets showing how to run remote commands over ssh using twisted would be highly appreciated. Thanks.
Followup - Happily, the ticket I referenced below is now resolved. The simpler API will be included in the next release of Twisted. The original answer is still a valid way to use Conch and may reveal some interesting details about what's going on, but from Twisted 13.1 on, if you just want to run a command and handle its I/O, this simpler interface will work.
It takes an unfortunately large amount of code to execute a command on an SSH server using the Conch client APIs. Conch makes you deal with a lot of different layers, even if you just want sensible boring default behavior. However, it's certainly possible. Here's some code which I've been meaning to finish and add to Twisted to simplify this case:
import sys, os
from zope.interface import implements
from twisted.python.failure import Failure
from twisted.python.log import err
from twisted.internet.error import ConnectionDone
from twisted.internet.defer import Deferred, succeed, setDebugging
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.internet.protocol import Factory, Protocol
from twisted.conch.ssh.common import NS
from twisted.conch.ssh.channel import SSHChannel
from twisted.conch.ssh.transport import SSHClientTransport
from twisted.conch.ssh.connection import SSHConnection
from twisted.conch.client.default import SSHUserAuthClient
from twisted.conch.client.options import ConchOptions

# setDebugging(True)

class _CommandTransport(SSHClientTransport):
    _secured = False

    def verifyHostKey(self, hostKey, fingerprint):
        return succeed(True)

    def connectionSecure(self):
        self._secured = True
        command = _CommandConnection(
            self.factory.command,
            self.factory.commandProtocolFactory,
            self.factory.commandConnected)
        userauth = SSHUserAuthClient(
            os.environ['USER'], ConchOptions(), command)
        self.requestService(userauth)

    def connectionLost(self, reason):
        if not self._secured:
            self.factory.commandConnected.errback(reason)

class _CommandConnection(SSHConnection):
    def __init__(self, command, protocolFactory, commandConnected):
        SSHConnection.__init__(self)
        self._command = command
        self._protocolFactory = protocolFactory
        self._commandConnected = commandConnected

    def serviceStarted(self):
        channel = _CommandChannel(
            self._command, self._protocolFactory, self._commandConnected)
        self.openChannel(channel)

class _CommandChannel(SSHChannel):
    name = 'session'

    def __init__(self, command, protocolFactory, commandConnected):
        SSHChannel.__init__(self)
        self._command = command
        self._protocolFactory = protocolFactory
        self._commandConnected = commandConnected

    def openFailed(self, reason):
        self._commandConnected.errback(reason)

    def channelOpen(self, ignored):
        self.conn.sendRequest(self, 'exec', NS(self._command))
        self._protocol = self._protocolFactory.buildProtocol(None)
        self._protocol.makeConnection(self)

    def dataReceived(self, bytes):
        self._protocol.dataReceived(bytes)

    def closed(self):
        self._protocol.connectionLost(
            Failure(ConnectionDone("ssh channel closed")))

class SSHCommandClientEndpoint(object):
    implements(IStreamClientEndpoint)

    def __init__(self, command, sshServer):
        self._command = command
        self._sshServer = sshServer

    def connect(self, protocolFactory):
        factory = Factory()
        factory.protocol = _CommandTransport
        factory.command = self._command
        factory.commandProtocolFactory = protocolFactory
        factory.commandConnected = Deferred()
        d = self._sshServer.connect(factory)
        d.addErrback(factory.commandConnected.errback)
        return factory.commandConnected

class StdoutEcho(Protocol):
    def dataReceived(self, bytes):
        sys.stdout.write(bytes)
        sys.stdout.flush()

    def connectionLost(self, reason):
        self.factory.finished.callback(None)

def copyToStdout(endpoint):
    echoFactory = Factory()
    echoFactory.protocol = StdoutEcho
    echoFactory.finished = Deferred()
    d = endpoint.connect(echoFactory)
    d.addErrback(echoFactory.finished.errback)
    return echoFactory.finished

def main():
    from twisted.python.log import startLogging
    from twisted.internet import reactor
    from twisted.internet.endpoints import TCP4ClientEndpoint

    # startLogging(sys.stdout)

    sshServer = TCP4ClientEndpoint(reactor, "localhost", 22)
    commandEndpoint = SSHCommandClientEndpoint("/bin/ls", sshServer)

    d = copyToStdout(commandEndpoint)
    d.addErrback(err, "ssh command / copy to stdout failed")
    d.addCallback(lambda ignored: reactor.stop())
    reactor.run()

if __name__ == '__main__':
    main()
Some things to note about it:
- It uses the new endpoint APIs introduced in Twisted 10.1. It's possible to do this directly on reactor.connectTCP, but I did it as an endpoint to make it more useful; endpoints can be swapped easily without the code that actually asks for a connection knowing.
- It does no host key verification at all! _CommandTransport.verifyHostKey is where you would implement that. Take a look at twisted/conch/client/default.py for some hints about what kinds of things you might want to do.
- It takes $USER to be the remote username, which you may want to be a parameter.
- It probably only works with key authentication. If you want to enable password authentication, you probably need to subclass SSHUserAuthClient and override getPassword to do something (see the sketch at the end of this answer).
Almost all of the layers of SSH and Conch are visible here:
- _CommandTransport is at the bottom, a plain old protocol that implements the SSH transport protocol. It creates a...
- _CommandConnection, which implements the SSH connection negotiation parts of the protocol. Once that completes, a...
- _CommandChannel is used to talk to a newly opened SSH channel. _CommandChannel does the actual exec to launch your command. Once the channel is opened, it creates an instance of...
- StdoutEcho, or whatever other protocol you supply. This protocol will get the output from the command you execute, and can write to the command's stdin.
See http://twistedmatrix.com/trac/ticket/4698 for progress in Twisted on supporting this with less code.
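As for the password authentication note above, here's a minimal sketch of what such a subclass might look like, reusing the imports from the code above (the hard-coded password is purely illustrative; in practice you would pull it from your own configuration):
class PasswordAuthClient(SSHUserAuthClient):
    def getPassword(self, prompt=None):
        # Return a Deferred that fires with the password, instead of
        # prompting interactively on the terminal.
        return succeed('my-remote-password')
You would then construct PasswordAuthClient instead of SSHUserAuthClient in _CommandTransport.connectionSecure().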
