Read/write values using Ethernet/IP - python

I recently acquired an ACS Linear Actuator (Tolomatic Stepper) that I am attempting to send data to from a Python application. The device itself communicates using the EtherNet/IP protocol.
I have installed the cpppo library via pip. When I issue a command to read the status of the device, I get None back. Examining the communication with Wireshark, it appears to proceed correctly; however, I notice a response from the device indicating:
Service not supported.
Example of the code I am using to test reading an "Input Assembly":
from cpppo.server.enip import client
HOST = "192.168.1.100"
TAGS = ["#4/100/3"]
with client.connector(host=HOST) as conn:
    for index, descr, op, reply, status, value in conn.synchronous(
            operations=client.parse_operations(TAGS)):
        print(": %20s: %s" % (descr, value))
I am expecting to read the "Input Assembly", but it does not appear to work that way. I imagine that I am missing something, as this is the first time I have attempted EtherNet/IP communication.
I am not sure how to proceed or what I am missing about EtherNet/IP that would make this work correctly.

clutton -- I'm the author of the cpppo module.
Sorry for the delayed response. We only recently implemented the ability to communicate with simple (non-routing) CIP devices. The ControlLogix/CompactLogix controllers implement an expanded set of EtherNet/IP CIP capabilities, which most simple CIP devices do not. Furthermore, simple devices typically do not implement the *Logix "Read Tag" request either; you have to get by with the basic "Get Attribute Single/All" requests, which just return raw 8-bit data. It is up to you to turn that back into a CIP REAL, INT, DINT, etc.
To communicate with your linear actuator, you will need to disable these enhanced encapsulations and use "Get Attribute Single" requests. You do this by specifying an empty route_path=[] and send_path='' when you parse your operations, and by using cpppo.server.enip.getattr's attribute_operations (instead of cpppo.server.enip.client's parse_operations):
from cpppo.server.enip import client
from cpppo.server.enip.getattr import attribute_operations
HOST = "192.168.1.100"
TAGS = ["#4/100/3"]
with client.connector(host=HOST) as conn:
    for index, descr, op, reply, status, value in conn.synchronous(
            operations=attribute_operations(
                TAGS, route_path=[], send_path='' )):
        print(": %20s: %s" % (descr, value))
That should do the trick!
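Because "Get Attribute Single" returns raw bytes, typed values have to be reassembled by hand. A minimal sketch with the standard struct module, assuming the reply value arrives as a sequence of byte values (0-255); CIP encodes its elementary types little-endian. decode_cip and CIP_FORMATS are hypothetical helpers, not cpppo APIs, and the byte offsets for a real device come from its assembly layout documentation ("#4/100/3" addresses class 4, the Assembly object, instance 100, attribute 3, its data):

```python
import struct

# CIP elementary types are little-endian ("<" in struct notation, no padding)
CIP_FORMATS = {
    "SINT": "<b", "USINT": "<B",
    "INT": "<h", "UINT": "<H",
    "DINT": "<i", "UDINT": "<I",
    "REAL": "<f", "LREAL": "<d",
}

def decode_cip(raw, cip_type, offset=0):
    """Decode one CIP value of cip_type starting at byte 'offset' of raw."""
    fmt = CIP_FORMATS[cip_type]
    size = struct.calcsize(fmt)
    chunk = bytes(raw[offset:offset + size])  # accepts a list of ints or bytes
    if len(chunk) != size:
        raise ValueError(f"need {size} bytes for {cip_type}, got {len(chunk)}")
    return struct.unpack(fmt, chunk)[0]

# Example payload: a REAL (1.0) followed by a DINT (300)
raw = [0x00, 0x00, 0x80, 0x3F, 0x2C, 0x01, 0x00, 0x00]
print(decode_cip(raw, "REAL"))     # 1.0
print(decode_cip(raw, "DINT", 4))  # 300
```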
We are in the process of rolling out a major update to the cpppo module, so clone the https://github.com/pjkundert/cpppo.git Git repo and check out the feature-list-identity branch to get early access to much better APIs for accessing raw data from these simple devices, for testing. You'll be able to use cpppo to convert the raw data into CIP REALs, instead of having to do it yourself...
...
With cpppo >= 3.9.0, you can now use the much more powerful cpppo.server.enip.get_attribute 'proxy' and 'proxy_simple' interfaces to routing CIP devices (e.g. ControlLogix, CompactLogix) and non-routing "simple" CIP devices (e.g. MicroLogix, PowerFlex, etc.):
$ python
>>> from cpppo.server.enip.get_attribute import proxy_simple
>>> product_name, = proxy_simple( '10.0.1.2' ).read( [('#1/1/7','SSTRING')] )
>>> product_name
[u'1756-L61/C LOGIX5561']
If you want regular updates, use cpppo.server.enip.poll:
import logging
import sys
import time
import threading
from cpppo.server.enip import poll
from cpppo.server.enip.get_attribute import proxy_simple as device
params = [('#1/1/1','INT'),('#1/1/7','SSTRING')]
# If you have an A-B PowerFlex, try:
# from cpppo.server.enip.ab import powerflex_750_series as device
# params = [ "Motor Velocity", "Output Current" ]
hostname = '10.0.1.2'
values = {} # { <parameter>: <value>, ... }
poller = threading.Thread(
    target=poll.poll, args=(device,), kwargs={
        'address': (hostname, 44818),
        'cycle': 1.0,
        'timeout': 0.5,
        'process': lambda par,val: values.update( { par: val } ),
        'params': params,
    })
poller.daemon = True
poller.start()
# Monitor the values dict (updated in another Thread)
while True:
    while values:
        logging.warning( "%16s == %r", *values.popitem() )
    time.sleep( .1 )
And, Voila! You now have regularly updating parameter names and values in your 'values' dict. See the examples in cpppo/server/enip/poll_example*.py for further details, such as how to report failures, control exponential back-off of connection retries, etc.
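The values dict coalesces updates: if a parameter changes twice between checks, only the latest value survives, and the main loop has to wake every 0.1 s to look. If you would rather see every update and block until one arrives, a queue.Queue fed by the same kind of process callback also works. A sketch under that assumption; fake_poll is a hypothetical stand-in for poll.poll so the example runs without a device, and the parameter tuple it reports is made up:

```python
import queue
import threading
import time

updates = queue.Queue()  # (parameter, value) pairs in arrival order

def process(par, val):
    # Same (par, val) shape as the 'process' callback handed to poll.poll above
    updates.put((par, val))

def fake_poll(process, cycles=3):
    # Hypothetical stand-in for poll.poll: reports one parameter 'cycles' times
    for n in range(cycles):
        process(("#1/1/1", "INT"), n)
        time.sleep(0.01)

poller = threading.Thread(target=fake_poll, args=(process,), daemon=True)
poller.start()

received = []
while poller.is_alive() or not updates.empty():
    try:
        # Blocks until an update arrives (or 0.1 s passes) instead of busy-polling
        par, val = updates.get(timeout=0.1)
    except queue.Empty:
        continue
    received.append((par, val))
print(received)
```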
Version 3.9.5 has recently been released with support for writing to CIP Tags and Attributes using the cpppo.server.enip.get_attribute proxy and proxy_simple APIs. See cpppo/server/enip/poll_example_many_with_write.py.

I hope this is obvious, but accessing HOST = "192.168.1.100" will only be possible from a system that can reach the 192.168.1.* subnet (normally, one located on that subnet).

Related

Forced Slowdown of Multiprocessing Generation vs OSError Too Many Open Files

I have the following code. In testing, I found that when I get to several hundred concurrent child processes (somewhere around 400?), I get "OSError: Too many open files". Any idea why?
I can solve the problem with the time.sleep(.005) call, but I shouldn't have to.
This is part of a larger program. A typical call will set a server string, a token string, and a list of many thousands of devices. The REST API call used can only handle a single device at a time. In testing, this resulted in a 20-minute execution time, but indications are that a multiprocessing approach can reduce it to around 30 seconds.
import urllib, requests, json, sys, pprint, time, multiprocessing as mp
assert sys.version_info >= (3, 6), "Must use Python 3.6+"
###########################
### handler function for multiprocessing worker
###########################
def getAttributesOneDevice(server, device, token, q):
    """Handler function for getting a single device"""
    serverURL = server + "/ServicesAPI/API/V1/CMDB/Devices/Attributes"
    headers = { "Accept" : "application/json",
                "Content-Type" : "application/json",
                "token" : token }
    query = { "hostname" : device }
    response = requests.get(serverURL, headers = headers, params = query, verify = False)
    q.put(response.json())
# end getAttributesOneDevice()
def GetDeviceAttributes(server = "", token = "", devices = []):
    """
    See this URL for explanation of what this function does
    https://github.com/NetBrainAPI/NetBrain-REST-API-V8.02/blob/master
    /REST APIs Documentation/Devices Management
    /Get Device Attributes API.md
    To summarize the URL: will acquire detailed device attributes for a single
    device.
    This subroutine therefore queries for all devices provided, and assembles the
    results into a single list of dicts.
    Server queries are relatively expensive. A single query is not a big deal,
    but accumulated across a massive device list this can take excessive
    time to execute (20min, etc). Therefore, this procedure is parallelized
    through multi-processing to complete in a reasonable amount of time.
    'server' should be a string that is just the http(s)://<FQDN>. Do not
    include the trailing '/'.
    'token' should be an authentication token that was generated by
    GetLoginToken and SetWorkingDomain modules in this directory.
    'devices' should be a list of strings, where each entry is a device.
    return a single dictionary:
    "Table" a list of dicts, each dict the detailed attributes of a device
    "Missed" a list of devices that had no result
    Note that failure to capture a device is distinct from function failure.
    """
    resultsTable = []
    MissedDevices = []
    procList = []
    for dev in devices:
        q = mp.Queue()
        proc = mp.Process(target=getAttributesOneDevice,
                          args=(server, dev, token, q))
        proc.start()
        procList += [ {"proc" : proc, "dev" : dev, "queue" : q} ]
        # If I don't do this as I'm going, I *always* get "OSError too many open files"
        updatedProcList = []
        for proc in procList:
            if proc["proc"].is_alive():
                updatedProcList += [proc]
            else:
                # kill zombies
                if proc["queue"].empty():
                    MissedDevices += [ proc["dev"] ]
                else:
                    queueData = proc["queue"].get()
                    resultsTable += [ queueData ]
                    while not proc["queue"].empty():
                        # drain whatever's left before we close out the process
                        proc["queue"].get()
                proc["proc"].join()
        procList = updatedProcList
        # if I don't do this, I get "OSError too many open files" at somewhere
        # around 375-400 child processes
        time.sleep(.005)
    # I could instead embed the list comprehension in the while statement,
    # but that would hinder troubleshooting
    remainingProcs = [ 1 ]
    while len(remainingProcs) > 0:
        remainingProcs = [ proc for proc in procList if proc["proc"].is_alive()]
        time.sleep(1)
    for proc in procList:
        # kill zombies
        if proc["queue"].empty():
            MissedDevices += [ proc["dev"] ]
        else:
            queueData = proc["queue"].get()
            resultsTable += [ queueData ]
            while not proc["queue"].empty():
                # drain whatever's left before we close out the process
                proc["queue"].get()
        proc["proc"].join()
    return { "Table" : resultsTable, "Missed" : MissedDevices }
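A likely explanation for the ~400 ceiling (an assumption, not measured on this program): in the parent process, each live mp.Process holds a sentinel pipe descriptor and each mp.Queue holds the two ends of a pipe, so every in-flight device costs roughly three file descriptors. Against the common Linux default soft limit of 1024 open files, that leaves room for only a few hundred concurrent workers. A sketch for checking the limit and, on Unix, raising the soft limit toward the hard limit; FDS_PER_WORKER and reserved are rough assumptions, not measured values:

```python
import resource

FDS_PER_WORKER = 3  # assumption: Process sentinel (1) + Queue pipe ends (2)

def max_concurrent_workers(soft_limit, fds_per_worker=FDS_PER_WORKER, reserved=50):
    """Rough ceiling on simultaneous Process+Queue pairs before EMFILE.

    'reserved' (an assumption) covers stdin/stdout/stderr, logs, sockets, etc.
    """
    return max(0, (soft_limit - reserved) // fds_per_worker)

def raise_nofile_soft_limit(target):
    """Raise the soft RLIMIT_NOFILE toward target, capped at the hard limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if hard != resource.RLIM_INFINITY:
        target = min(target, hard)
    if soft != resource.RLIM_INFINITY and target > soft:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
    return resource.getrlimit(resource.RLIMIT_NOFILE)[0]

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
if soft == resource.RLIM_INFINITY:
    print("no soft limit on open files")
else:
    print("soft limit:", soft, "rough worker ceiling:", max_concurrent_workers(soft))
```

Raising the limit only moves the ceiling; bounding the number of simultaneous workers is the more robust fix.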
You should be using multithreading with a thread pool (which can easily handle up to 500 threads), since getAttributesOneDevice spends almost all of its time waiting for a network request to complete. You should also use a requests.Session object for the GET requests because, according to the documentation:
The Session object allows you to persist certain parameters across requests. It also persists cookies across all requests made from the Session instance, and will use urllib3’s connection pooling. So if you’re making several requests to the same host, the underlying TCP connection will be reused, which can result in a significant performance increase (see HTTP persistent connection).
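One detail worth adding, assuming all requests go to the single NetBrain host: requests' default connection pool keeps at most 10 connections per host, so with hundreds of threads sharing one Session, most connections are opened, used once and then discarded rather than reused. Mounting an HTTPAdapter sized to the thread count avoids that. A sketch; make_session is a hypothetical helper:

```python
import requests
from requests.adapters import HTTPAdapter

def make_session(pool_size):
    """Session whose connection pool can keep one connection per worker thread."""
    session = requests.Session()
    # requests' default pool_maxsize is 10: with more concurrent threads, the
    # extra connections still get opened but are discarded after each request
    adapter = HTTPAdapter(pool_connections=1, pool_maxsize=pool_size)
    session.mount("https://", adapter)
    session.mount("http://", adapter)
    return session

# Match the pool to the ThreadPool size used in this answer (up to 500)
session = make_session(500)
```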
The worker function, getAttributesOneDevice, should be modified to raise an exception if it fails to capture a device.
import urllib, requests, json, sys, pprint, time
from multiprocessing.pool import ThreadPool
from functools import partial
assert sys.version_info >= (3, 6), "Must use Python 3.6+"
###########################
### handler function for multithreading worker
###########################
def getAttributesOneDevice(session, serverURL, token, device):
    """Handler function for getting a single device"""
    query = { "hostname" : device }
    response = session.get(serverURL, params = query, verify = False)
    # Raise an exception if unable to capture a device
    response.raise_for_status()
    # Should the response itself be checked to ensure a device was captured
    # and an exception be raised if not?
    return response.json()
def GetDeviceAttributes(server = "", token = "", devices = []):
    """
    See this URL for explanation of what this function does
    https://github.com/NetBrainAPI/NetBrain-REST-API-V8.02/blob/master
    /REST APIs Documentation/Devices Management
    /Get Device Attributes API.md
    To summarize the URL: will acquire detailed device attributes for a single
    device.
    This subroutine therefore queries for all devices provided, and assembles the
    results into a single list of dicts.
    Server queries are relatively expensive. A single query is not a big deal,
    but accumulated across a massive device list this can take excessive
    time to execute (20min, etc). Therefore, this procedure is parallelized
    through multithreading to complete in a reasonable amount of time.
    'server' should be a string that is just the http(s)://<FQDN>. Do not
    include the trailing '/'.
    'token' should be an authentication token that was generated by
    GetLoginToken and SetWorkingDomain modules in this directory.
    'devices' should be a list of strings, where each entry is a device.
    return a single dictionary:
    "Table" a list of dicts, each dict the detailed attributes of a device
    "Missed" a list of devices that had no result
    Note that failure to capture a device is distinct from function failure.
    """
    resultsTable = []
    MissedDevices = []
    if not devices:
        # ThreadPool(0) raises ValueError, so handle the empty case up front
        return { "Table" : resultsTable, "Missed" : MissedDevices }
    with requests.Session() as session, \
         ThreadPool(min(len(devices), 500)) as pool:
        session.headers.update({ "Accept" : "application/json",
                                 "Content-Type" : "application/json",
                                 "token" : token })
        # Compute this once here:
        serverURL = server + "/ServicesAPI/API/V1/CMDB/Devices/Attributes"
        # The session, serverURL and token arguments never vary:
        worker = partial(getAttributesOneDevice, session, serverURL, token)
        # imap yields results in the same order as devices and re-raises,
        # at that position, any exception the worker raised
        results = pool.imap(worker, devices)
        device_index = 0
        while True:
            try:
                resultsTable.append(next(results))
            except StopIteration:
                break
            except Exception:
                # This is the device that caused the exception.
                # The assumption is that devices is indexable:
                MissedDevices.append(devices[device_index])
            finally:
                device_index += 1
    return { "Table" : resultsTable, "Missed" : MissedDevices }
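A variant of the result loop, offered as a sketch: catching the exception inside the worker means each result carries its own device name, so no index bookkeeping is needed and imap_unordered can hand results back as they finish. map_collect, _call_capturing and fake_fetch are hypothetical names; in the real code, func would be partial(getAttributesOneDevice, session, serverURL, token):

```python
from functools import partial
from multiprocessing.pool import ThreadPool

def _call_capturing(func, item):
    """Run func(item); return (item, value, error) instead of raising."""
    try:
        return item, func(item), None
    except Exception as exc:
        return item, None, exc

def map_collect(func, items, max_threads=500):
    """Apply func to every item in a thread pool, splitting successes from failures."""
    table, missed = [], []
    if not items:
        return {"Table": table, "Missed": missed}  # ThreadPool(0) raises ValueError
    with ThreadPool(min(len(items), max_threads)) as pool:
        # imap_unordered yields each (item, value, error) as soon as it finishes
        for item, value, error in pool.imap_unordered(
                partial(_call_capturing, func), items):
            if error is None:
                table.append(value)
            else:
                missed.append(item)
    return {"Table": table, "Missed": missed}

def fake_fetch(device):
    # Hypothetical stand-in for getAttributesOneDevice
    if device.startswith("bad"):
        raise RuntimeError("no such device")
    return {"hostname": device}

outcome = map_collect(fake_fetch, ["r1", "bad1", "r2"])
print(outcome)
```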
Thank you @Booboo and @Olvin-Roght for your help. I've marked Booboo's answer as the accepted answer, but ultimately it was a combination of the two. Here's what I actually ended up with, for transparency and in case it helps anyone else in the future; I only got here with the pointers from the folks on this thread. Thank you to everyone.
import requests, json, sys, concurrent.futures
# tested on 3.6.8
assert sys.version_info >= (3, 6), "Must use Python 3.6+"
# needed to REALLY condense the letters in the function name for submit()
def get1Dev(server, device, token):
    """Handler function for getting a single device. Raises an exception on an
    unsuccessful query, or returns the dict of the detailed attributes table."""
    serverURL = server + "/ServicesAPI/API/V1/CMDB/Devices/Attributes"
    query = { "hostname" : device }
    headers = {
        "Accept" : "application/json",
        "Content-Type" : "application/json",
        "token" : token
    }
    # ProcessPoolExecutor captures any exception raised here and re-raises it
    # from the corresponding Future's result()
    response = requests.get(serverURL,headers=headers,params=query,verify=False)
    if response.status_code != 200:
        raise Exception(str(response.status_code) + " returned from server")
    responseStatusDescription = response.json()["statusDescription"]
    if "Success." not in responseStatusDescription:
        raise Exception(f"{responseStatusDescription} returned from server")
    return response.json()["attributes"]
# end get1Dev()
def GetDeviceAttributes(server = "", token = "", devices = []):
    """
    See this URL for explanation of what this function does
    https://github.com/NetBrainAPI/NetBrain-REST-API-V8.02/blob/master
    /REST APIs Documentation/Devices Management
    /Get Device Attributes API.md
    To summarize the URL: will acquire detailed device attributes for a single
    device.
    This subroutine therefore queries for all devices provided, and assembles the
    results into a single list of dicts.
    Server queries are relatively expensive. A single query is not a big deal,
    but accumulated across a massive device list this can take excessive
    time to execute (20min, etc). Therefore, this procedure is parallelized
    through multi-processing to complete in a reasonable amount of time.
    'server' should be a string that is just the http(s)://<FQDN>. Do not
    include the trailing '/'.
    'token' should be an authentication token that was generated by
    GetLoginToken and SetWorkingDomain modules in this directory.
    'devices' should be a list of strings, where each entry is a device.
    return a single dictionary:
    "Table" a list of dicts, each dict the detailed attributes of a device
    "Missed" a list of devices that had no result
    Note that failure to capture a device is distinct from function failure.
    """
    # will raise() if needed; purposefully not including in a "try" block
    # (inputValidate is a helper defined elsewhere in this module)
    inputValidate(server, token, devices)
    # remove all empty strings in 'devices'
    devices = [ dev for dev in devices if dev ]
    resultsTable = []
    MissedDevices = []
    # profiling data for max_workers, 06 Feb 2022:
    # single-threaded single-processed: 20min (ish)
    # 61 is default, took 6min
    # 10000 resulted in "too many open files" error
    # detailed analysis: "too many open files" occurs ~350-400 child procs
    # 300 took 40sec, so we'll call that the "sweet spot"
    # the with-block shuts the worker processes down when we are done
    with concurrent.futures.ProcessPoolExecutor(max_workers=300) as exe:
        # exe.map() is too trivial in that it halts on first exception :(
        # derived from example for ThreadPoolExecutor
        # https://docs.python.org/3/library/concurrent.futures.html
        results = {exe.submit(get1Dev, server, dev, token): dev for dev in devices}
        # note: if the 300 s timeout expires, as_completed() itself raises
        # TimeoutError from the for statement, outside the try block below
        for result in concurrent.futures.as_completed(results, timeout=300):
            try:
                # separating calling result() from the += allows exception handling
                # without worrying about adding error result to resultsTable
                res = result.result()
                resultsTable += [ res ]
            # concurrent.futures.TimeoutError is itself a subclass of Exception
            except (Exception, concurrent.futures.TimeoutError):
                # 'results' maps each Future back to the device it was submitted for
                MissedDevices += [ results[result] ]
    return { "Table" : resultsTable, "Missed" : MissedDevices }
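One caveat with this code: the timeout=300 passed to as_completed() is enforced by the iterator itself, so when it expires, concurrent.futures.TimeoutError is raised by the for statement, outside the try block, and the whole function fails instead of reporting the stragglers as missed. A sketch of one way to handle that. It uses ThreadPoolExecutor so the demo runs anywhere; the same loop should work with ProcessPoolExecutor as long as the worker is a picklable top-level function. collect_with_deadline and slow_or_bad are hypothetical names:

```python
import concurrent.futures
import time

def collect_with_deadline(func, items, max_workers=8, timeout=300):
    """Run func(item) for every item; failures and stragglers go to "Missed"."""
    table, missed, seen = [], [], set()
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as exe:
        futures = {exe.submit(func, item): item for item in items}
        try:
            for fut in concurrent.futures.as_completed(futures, timeout=timeout):
                seen.add(fut)
                try:
                    table.append(fut.result())
                except Exception:
                    missed.append(futures[fut])
        except concurrent.futures.TimeoutError:
            # Raised by as_completed() itself, not by result()
            for fut, item in futures.items():
                if fut not in seen:
                    fut.cancel()  # only stops tasks that have not started yet
                    missed.append(item)
    # Leaving the with-block still waits for tasks that were already running
    return {"Table": table, "Missed": missed}

def slow_or_bad(item):
    # Hypothetical stand-in for get1Dev
    if item == "bad":
        raise ValueError("query failed")
    if item == "slow":
        time.sleep(0.5)
    return item.upper()

result = collect_with_deadline(slow_or_bad, ["a", "bad", "slow"], timeout=0.2)
print(result)
```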

Python CAN receiving data (pdo mapping)

I am currently working on making a CAN tracer with Python. The connection as well as the received data from CAN work. My question now is: how can I change the PDO mapping and stop and start the transmission via CAN like it works with CANopen?
import can
import canopen
# CAN Setting
can_interface = '0'
can_filters = [{"can_id":0x018A, "can_mask": 0xFFFF, "extended": True}]
bus = can.interface.Bus(can_interface, bustype='ixxat',can_filters=can_filters)
while True:
    message = bus.recv()
    print(message)
You should put the node into pre-operational mode and send SDO requests to change the PDO mapping.
You can find the relevant objects in the device documentation or in its EDS file.
If you make modifications, do not forget to send the "save" SDO request at the end; otherwise the node will restart with its default values.
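These steps map onto the python-canopen package roughly as follows. This is a sketch, not tested against your hardware: the EDS path, node ID 10 and the "Statusword" entry are assumptions (0x18A in your filter is the default TPDO1 COB-ID, 0x180 + node ID, which suggests node 10). Switching the node to PRE-OPERATIONAL stops its PDO transmission and switching it back to OPERATIONAL restarts it, which also covers the start/stop part of the question. Writing the ASCII signature "save" to object 0x1010 sub-index 1 is the standard CANopen store-all-parameters request:

```python
def tpdo_cob_id(node_id, pdo_number):
    """Default COB-ID of TPDO1..TPDO4 in the CANopen pre-defined connection set."""
    if not 1 <= pdo_number <= 4:
        raise ValueError("the pre-defined connection set only covers TPDO1-TPDO4")
    return 0x180 + 0x100 * (pdo_number - 1) + node_id

def remap_tpdo1(eds_path, node_id=10, channel=0, variables=("Statusword",)):
    """Remap TPDO1 of one node over SDO and store it (needs a real bus)."""
    import canopen  # imported here so tpdo_cob_id() works without the package

    network = canopen.Network()
    network.connect(interface="ixxat", channel=channel)
    try:
        node = network.add_node(node_id, eds_path)
        node.nmt.state = "PRE-OPERATIONAL"  # stops PDO transmission
        node.tpdo.read()                    # upload the current mapping via SDO
        node.tpdo[1].clear()
        for name in variables:              # object names as spelled in the EDS
            node.tpdo[1].add_variable(name)
        node.tpdo[1].trans_type = 254       # event-driven, manufacturer-specific
        node.tpdo[1].enabled = True
        node.tpdo.save()                    # download the new mapping via SDO
        # Store parameters: write the ASCII signature "save" to 0x1010 sub-index 1
        node.sdo.download(0x1010, 1, b"save")
        node.nmt.state = "OPERATIONAL"      # restarts PDO transmission
    finally:
        network.disconnect()

print(hex(tpdo_cob_id(10, 1)))  # 0x18a, the ID filtered for in the question
```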

Generate TSIG keyring (as encoded byte string) for DNS Update

I am trying to use the Python DNS module (dnspython) to create (add) a new DNS record.
The documentation at http://www.dnspython.org/examples.html shows how to create an update:
import dns.tsigkeyring
import dns.update
import sys
keyring = dns.tsigkeyring.from_text({
'host-example.' : 'XXXXXXXXXXXXXXXXXXXXXX=='
})
update = dns.update.Update('dyn.test.example', keyring=keyring)
update.replace('host', 300, 'a', sys.argv[1])
But it does not explain how to generate the keyring string that is passed to the dns.tsigkeyring.from_text() method in the first place.
What is the correct way to generate the key? I am using krb5 at my organization.
Server is running on Microsoft AD DNS with GSS-TSIG.
TSIG and GSS-TSIG are different beasts – the former uses a static preshared key that can be simply copied from the server, but the latter uses Kerberos (GSSAPI) to negotiate a session key for every transaction.
At the time when this thread was originally posted, dnspython 1.x did not have any support for GSS-TSIG whatsoever.
(The handshake does not result in a static key that could be converted to a regular TSIG keyring; instead the GSSAPI library itself must be called to build an authenticator – dnspython 1.x could not do that, although dnspython 2.1 finally can.)
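For plain TSIG (this does not apply to the GSS-TSIG setup used by AD), there is nothing to derive: the keyring value is just the base64-encoded shared secret, and the very same secret has to be configured on the server. A sketch of generating one with the standard library, assuming a BIND server and HMAC-SHA256; both helper names are hypothetical:

```python
import base64
import secrets

def make_tsig_secret(num_bytes=32):
    """Random 256-bit (by default) secret, base64-encoded as TSIG keys are."""
    return base64.b64encode(secrets.token_bytes(num_bytes)).decode("ascii")

def bind_key_stanza(key_name, secret, algorithm="hmac-sha256"):
    """named.conf 'key' clause so the server knows the same secret."""
    return f'key "{key_name}" {{\n    algorithm {algorithm};\n    secret "{secret}";\n}};\n'

secret = make_tsig_secret()
print(bind_key_stanza("host-example.", secret))
# The client side is then just {key name: secret}, as in the question:
# keyring = dns.tsigkeyring.from_text({"host-example.": secret})
```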
If you are trying to update an Active Directory DNS server, BIND's nsupdate command-line tool supports GSS-TSIG (and sometimes it even works). You should be able to run it through subprocess and simply feed the necessary updates via stdin.
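import subprocess
# dyn_zone, fqdn and challenge are placeholders for your own zone,
# record name and TXT value
cmds = [f'zone {dyn_zone}\n',
        f'del {fqdn}\n',
        f'add {fqdn} 60 TXT "{challenge}"\n',
        f'send\n']
subprocess.run(["nsupdate", "-g"],
               input="".join(cmds).encode(),
               check=True)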
As with most Kerberos client applications, nsupdate expects the credentials to be already present in the environment (that is, you need to have already obtained a TGT using kinit beforehand; or alternatively, if a recent version of MIT Krb5 is used, you can point $KRB5_CLIENT_KTNAME to the keytab containing the client credentials).
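The A-record replacement from the question (update.replace('host', 300, 'a', ...)) can be expressed as an nsupdate script in the same way: delete the existing A records for the name, then add the new ones. A sketch; the zone, name and address are placeholders, and run_nsupdate_gss needs nsupdate installed and a valid Kerberos ticket:

```python
import subprocess

def nsupdate_script(zone, fqdn, ttl, rtype, *rdata):
    """nsupdate commands that replace every fqdn/rtype record with rdata."""
    lines = [f"zone {zone}", f"update delete {fqdn} {rtype}"]
    lines += [f"update add {fqdn} {ttl} {rtype} {value}" for value in rdata]
    lines.append("send")
    return "\n".join(lines) + "\n"

def run_nsupdate_gss(script):
    """Feed the script to 'nsupdate -g' (GSS-TSIG); needs a Kerberos TGT."""
    subprocess.run(["nsupdate", "-g"], input=script.encode(), check=True)

script = nsupdate_script("dyn.test.example", "host.dyn.test.example.",
                         300, "A", "192.0.2.10")
print(script, end="")
```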
Update: dnspython 2.1 finally has the necessary pieces for GSS-TSIG, but creating the keyring is currently a very manual process: you have to call the GSSAPI library and process the TKEY negotiation yourself. The code for doing so is included below.
(The Python code below can be passed a custom gssapi.Credentials object, but otherwise it looks for credentials in the environment just like nsupdate does.)
import socket
import time
import uuid
import dns.exception
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.rdtypes.ANY.TKEY
import dns.resolver
import dns.tsig
import dns.update
import gssapi
def _build_tkey_query(token, key_ring, key_name):
    inception_time = int(time.time())
    tkey = dns.rdtypes.ANY.TKEY.TKEY(dns.rdataclass.ANY,
                                     dns.rdatatype.TKEY,
                                     dns.tsig.GSS_TSIG,
                                     inception_time,
                                     inception_time,
                                     3,  # TKEY mode 3 = GSS-API negotiation (RFC 2930)
                                     dns.rcode.NOERROR,
                                     token,
                                     b"")
    query = dns.message.make_query(key_name,
                                   dns.rdatatype.TKEY,
                                   dns.rdataclass.ANY)
    query.keyring = key_ring
    query.find_rrset(dns.message.ADDITIONAL,
                     key_name,
                     dns.rdataclass.ANY,
                     dns.rdatatype.TKEY,
                     create=True).add(tkey)
    return query
def _probe_server(server_name, zone):
    gai = socket.getaddrinfo(str(server_name),
                             "domain",
                             socket.AF_UNSPEC,
                             socket.SOCK_DGRAM)
    # Return the first address that answers an SOA query for the zone
    for af, sf, pt, cname, sa in gai:
        query = dns.message.make_query(zone, "SOA")
        try:
            dns.query.udp(query, sa[0], timeout=2)
        except dns.exception.Timeout:
            continue
        return sa[0]
    raise RuntimeError(f"no address of {server_name} answered for {zone}")
def gss_tsig_negotiate(server_name, server_addr, creds=None):
    # Acquire GSSAPI credentials (host-based service name "DNS@<server>")
    gss_name = gssapi.Name(f"DNS@{server_name}",
                           gssapi.NameType.hostbased_service)
    gss_ctx = gssapi.SecurityContext(name=gss_name,
                                     creds=creds,
                                     usage="initiate")
    # Name generation tips: https://tools.ietf.org/html/rfc2930#section-2.1
    key_name = dns.name.from_text(f"{uuid.uuid4()}.{server_name}")
    tsig_key = dns.tsig.Key(key_name, gss_ctx, dns.tsig.GSS_TSIG)
    key_ring = {key_name: tsig_key}
    key_ring = dns.tsig.GSSTSigAdapter(key_ring)
    token = gss_ctx.step()
    while not gss_ctx.complete:
        tkey_query = _build_tkey_query(token, key_ring, key_name)
        response = dns.query.tcp(tkey_query, server_addr, timeout=5)
        if not gss_ctx.complete:
            # Original comment:
            # https://github.com/rthalley/dnspython/pull/530#issuecomment-658959755
            # "this if statement is a bit redundant, but if the final token comes
            # back with TSIG attached the patch to message.py will automatically step
            # the security context. We don't want to excessively step the context."
            token = gss_ctx.step(response.answer[0][0].key)
    return key_name, key_ring
def gss_tsig_update(zone, update_msg, creds=None):
    # Find the SOA of our zone
    answer = dns.resolver.resolve(zone, "SOA")
    soa_server = answer.rrset[0].mname
    server_addr = _probe_server(soa_server, zone)
    # Get the GSS-TSIG key
    key_name, key_ring = gss_tsig_negotiate(soa_server, server_addr, creds)
    # Dispatch the update
    update_msg.use_tsig(keyring=key_ring,
                        keyname=key_name,
                        algorithm=dns.tsig.GSS_TSIG)
    response = dns.query.tcp(update_msg, server_addr)
    return response

How do I connect dbus and policykit to my function in python?

I am making a Python application that has a method needing root privileges. From https://www.freedesktop.org/software/polkit/docs/0.105/polkit-apps.html, I found "Example 2. Accessing the Authority via D-Bus", which is the Python version of the code below. I executed it and thought I'd be able to get root privileges after entering my password, but I'm still getting "permission denied" in my app. This is the function I'm trying to connect:
import dbus
bus = dbus.SystemBus()
proxy = bus.get_object('org.freedesktop.PolicyKit1', '/org/freedesktop/PolicyKit1/Authority')
authority = dbus.Interface(proxy, dbus_interface='org.freedesktop.PolicyKit1.Authority')
system_bus_name = bus.get_unique_name()
subject = ('system-bus-name', {'name' : system_bus_name})
action_id = 'org.freedesktop.policykit.exec'
details = {}
flags = 1 # AllowUserInteraction flag
cancellation_id = '' # No cancellation id
result = authority.CheckAuthorization(subject, action_id, details, flags, cancellation_id)
print(result)
In the python code you quoted, does result indicate success or failure? If it fails, you need to narrow down the error by first of all finding out what the return values of bus, proxy, authority and system_bus_name are. If it succeeds, you need to check how you are using the result.
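To expand on this: CheckAuthorization only reports whether the subject is authorized for action_id; it never changes the privileges of the calling process. The reply is an (is_authorized, is_challenge, details) structure, so even a successful check leaves your Python process running as your user, which would explain the "permission denied". The privileged work has to happen in a process that already runs as root; since org.freedesktop.policykit.exec is the action used by pkexec, the usual route is to run the privileged part through pkexec. A sketch; both helper names are hypothetical and the command you pass is your own:

```python
import subprocess

def polkit_authorized(result):
    """Interpret CheckAuthorization's (is_authorized, is_challenge, details) reply."""
    is_authorized, is_challenge, details = result
    return bool(is_authorized)

def run_privileged(argv):
    """Run argv as root via pkexec, which performs the polkit check itself."""
    return subprocess.run(["pkexec", *argv], check=True)

# A dbus-python reply unpacks like a tuple, e.g. (Boolean(True), Boolean(False), {})
print(polkit_authorized((True, False, {})))  # True
```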

How to get client's IP in a python thrift server

I'm writing a Thrift service in Python and I would like to understand how I can get the client's IP address in the handler functions' context.
Thanks,
Love.
You need to obtain the transport and get the data from there. I'm not sure how to do this exactly in Python, but there's a mailing list thread and a JIRA ticket, THRIFT-1053, describing a solution for C++/Java.
This is the relevant part from the mailing list thread:
I did it by decorating the TProcessor like this pseudo-code.
-craig
class TrackingProcessor implements TProcessor {
    private final TProcessor processor;
    TrackingProcessor(TProcessor processor) { this.processor = processor; }
    public boolean process(TProtocol in, TProtocol out) throws TException {
        TTransport t = in.getTransport();
        InetAddress ia = t instanceof TSocket ?
            ((TSocket) t).getSocket().getInetAddress() : null;
        // Now you have the IP address, so do whatever you want.
        // Delegate to the processor we are decorating.
        return processor.process(in, out);
    }
}
This is a bit old but I'm currently solving the same problem.
Here's my solution with thriftpy:
import thriftpy
from thriftpy.thrift import TProcessor, TApplicationException, TType
from thriftpy.server import TThreadedServer
from thriftpy.protocol import TBinaryProtocolFactory
from thriftpy.transport import TBufferedTransportFactory, TServerSocket
class CustomTProcessor(TProcessor):
    def process_in(self, iprot):
        api, type, seqid = iprot.read_message_begin()
        if api not in self._service.thrift_services:
            iprot.skip(TType.STRUCT)
            iprot.read_message_end()
            return api, seqid, TApplicationException(TApplicationException.UNKNOWN_METHOD), None # noqa
        args = getattr(self._service, api + "_args")()
        args.read(iprot)
        iprot.read_message_end()
        result = getattr(self._service, api + "_result")()
        # convert kwargs to args
        api_args = [args.thrift_spec[k][1] for k in sorted(args.thrift_spec)]
        # get client IP address
        client_ip, client_port = iprot.trans.sock.getpeername()
        def call():
            f = getattr(self._handler, api)
            return f(*(args.__dict__[k] for k in api_args), client_ip=client_ip)
        return api, seqid, result, call
class PingPongDispatcher:
    def ping(self, param1, param2, client_ip):
        return "pong %s" % client_ip
pingpong_thrift = thriftpy.load("pingpong.thrift")
processor = CustomTProcessor(pingpong_thrift.PingService, PingPongDispatcher())
server_socket = TServerSocket(host="127.0.0.1", port=12345, client_timeout=10000)
server = TThreadedServer(processor,
                         server_socket,
                         iprot_factory=TBinaryProtocolFactory(),
                         itrans_factory=TBufferedTransportFactory())
server.serve()
Remember that every method in the dispatcher will be called with the extra keyword argument client_ip.
The only way I found to get the TProtocol at the service handler is to extend the processor and create one handler instance for each client related by transport/protocol. Here's an example:
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.thrift.TException;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
public class MyProcessor implements TProcessor {
    // Maps client addresses to processors
    private static final Map<String, Service.Processor<ServiceHandler>> PROCESSORS =
        Collections.synchronizedMap(new HashMap<String, Service.Processor<ServiceHandler>>());
    // Maps client addresses to handlers
    private static final Map<String, ServiceHandler> HANDLERS =
        Collections.synchronizedMap(new HashMap<String, ServiceHandler>());
    @Override
    public boolean process(final TProtocol in, final TProtocol out)
            throws TException {
        // Get the socket for this request
        final TTransport t = in.getTransport();
        // Note that this cast will fail if the transport is not a socket, so you might want to add some checking.
        final TSocket socket = (TSocket) t;
        // Key both maps by the client's remote address (IP and port)
        final String clientRemote = socket.getSocket().getRemoteSocketAddress().toString();
        // Get existing processor for this socket if any
        Service.Processor<ServiceHandler> processor = PROCESSORS.get(clientRemote);
        // If there's no processor, create a processor and a handler for
        // this client and link them to this new socket
        if (processor == null) {
            // Inform the handler of its socket
            final ServiceHandler handler = new ServiceHandler(socket);
            processor = new Service.Processor<ServiceHandler>(handler);
            PROCESSORS.put(clientRemote, processor);
            HANDLERS.put(clientRemote, handler);
        }
        return processor.process(in, out);
    }
}
Then you need to tell Thrift to use this processor for incoming requests. For a TThreadPoolServer it goes like this:
final TThreadPoolServer.Args args = new TThreadPoolServer.Args(new TServerSocket(port));
args.processor(new MyProcessor());
final TThreadPoolServer server = new TThreadPoolServer(args);
The PROCESSORS map might look superfluous, but it is not since there's no way to get the handler for a processor (i.e. there's no getter).
Note that it is your ServiceHandler instance that needs to keep track of which socket it is associated with. Here I pass it in the constructor, but any way will do. Then, when the ServiceHandler's Iface implementation is called, it will already have the associated socket.
This also means you will have a Service.Processor and a ServiceHandler instance for each connected client, which I think is not the case with base Thrift, where only one instance of each class is created.
This solution also has a quite annoying drawback: you need to figure out a method to remove obsolete data from PROCESSORS and HANDLERS maps, otherwise these maps will grow indefinitely. In my case each client has a unique ID, so I can check if there are obsolete sockets for this client and remove them from the maps.
PS: the Thrift guys need to figure out a way to let the service handler get the used protocol for current call (for example by allowing to extend a base class instead of implementing an interface). This is very useful in so many scenarios.
