root/openid/consumer.py

Revision 16, 18.9 kB (checked in by devja..@anarkystic.com, 1 year ago)

initial import of the openid-on-twisted bits

Line 
1 """
2 This stuff is copied liberally from the JanRain open id code and modified
3 where appropriate to work within Twisted.
4
5 The original code has an Apache 2.0 License.
6 """
7
8
9 from pprint import pprint
10
11 from twisted.application import service
12 from twisted.internet import defer, reactor, protocol
13 from twisted.python import usage
14 from twisted.web import client
15
16 # here they come...
17 from openid.yadis import discover
18 from openid.fetchers import HTTPResponse
19 from openid.yadis.services import applyFilter as extractServices
20 from openid.consumer.discover import \
21         OpenIDServiceEndpoint, getOPOrUserServices
22 from openid.message import Message, OPENID_NS, OPENID2_NS, OPENID1_NS, \
23      IDENTIFIER_SELECT, no_default
24 from openid.yadis.etxrd import XRDSError
25 from openid.association import Association, default_negotiator
26 from openid.consumer import consumer
27 from openid import oidutil
28 from openid.store.nonce import mkNonce, split as splitNonce
29 from openid import sreg
30
31
32
33 from openid.yadis.constants import \
34         YADIS_HEADER_NAME, YADIS_CONTENT_TYPE, YADIS_ACCEPT_HEADER
35        
class HTTPClientFactory(client.HTTPClientFactory):
    """
    Variant of twisted.web.client.HTTPClientFactory whose deferred fires
    with an openid.fetchers.HTTPResponse (url, status, headers, body)
    rather than with the bare page body.
    """
    def page(self, page):
        # Only the first delivery of a page fires the deferred.
        if not self.waiting:
            return
        self.waiting = 0

        # response_headers maps each header name to a list of values;
        # HTTPResponse wants a single value, so keep the first one.
        first_values = dict((name, values[0])
                            for name, values in self.response_headers.iteritems())

        self.deferred.callback(
            HTTPResponse(self.url, int(self.status), first_values, page))
53
54
def getPage(url, contextFactory=None, headers=None,
            *args, **kw):
    """
    Download a page. Return a deferred, which will callback with an
    openid.fetchers.HTTPResponse (built by this module's HTTPClientFactory)
    or errback with a description of the error.

    For POST requests a form-urlencoded Content-type header is added unless
    the caller already supplied one (header names compared
    case-insensitively). The caller's ``headers`` dict is never modified.

    Extra positional/keyword arguments are passed to HTTPClientFactory
    (e.g. ``method``, ``postdata``).
    """
    # Work on a copy so the default header added below never leaks back
    # into a dict owned by the caller.
    headers = dict(headers or {})

    if kw.get("method", "GET") == "POST":
        existing = [name for name in headers
                    if name.lower() == "content-type"]
        if not existing:
            headers["Content-type"] = "application/x-www-form-urlencoded"

    scheme, host, port, path = client._parse(url)
    factory = HTTPClientFactory(url, headers=headers, *args, **kw)
    if scheme == 'https':
        # Imported lazily so plain-HTTP use does not require SSL support.
        from twisted.internet import ssl
        if contextFactory is None:
            contextFactory = ssl.ClientContextFactory()
        reactor.connectSSL(host, port, factory, contextFactory)
    else:
        reactor.connectTCP(host, port, factory)
    return factory.deferred
81
82
class OpenIDConsumerService(service.Service):
    """
    Twisted service bundling OpenID discovery (OpenIDDiscoverer) with
    association and auth-request building (OpenIDConsumer).
    """
    # Store handed to OpenIDConsumer; None means no associations are
    # stored or used by OpenIDConsumer.buildRequest.
    store = None

    def __init__(self, config):
        self.config = config

        self.consumer = OpenIDConsumer(self.store)
        self.discoverer = OpenIDDiscoverer()

    def startService(self):
        service.Service.startService(self)
        self.test()

    def test(self):
        """
        Smoke test run from startService: discover a hard-coded URL,
        associate with its first OpenID service, build an auth request with
        an SReg extension and compute its redirect URL, pprint-ing the
        intermediate results.

        NOTE(review): debugging scaffolding with hard-coded URLs.
        """
        url = "http://termie.wordpress.com"

        def _dump(result):
            # Print and pass the result (or Failure) through unchanged.
            pprint(result)
            return result

        d = self.discover(url)
        d.addBoth(_dump)

        def _assoc(rv):
            services = rv[1]
            if not services:
                # Fail loudly here instead of returning None, which used to
                # surface later as an opaque unpacking error.
                raise ValueError('no OpenID services found for %r' % (url,))
            endpoint = services[0]
            d2 = self.associate(endpoint)
            d2.addCallback(lambda assoc: (assoc, endpoint))
            return d2

        d.addCallback(_assoc)
        d.addBoth(_dump)

        def _build_request(assoc_and_endpoint):
            assoc, endpoint = assoc_and_endpoint
            return self.buildRequest(endpoint, assoc)
        d.addCallback(_build_request)

        def _add_sreg(req):
            sreg_req = sreg.SRegRequest(optional=['nickname', 'email'])
            req.addExtension(sreg_req)
            return req
        d.addCallback(_add_sreg)

        def _get_url(req):
            return req.redirectURL("http://flek.local", "http://flek.local")
        d.addCallback(_get_url)
        d.addBoth(_dump)

        return d

    def associate(self, endpoint):
        """Return a Deferred firing with an association for ``endpoint``."""
        return self.consumer.getAssociation(endpoint)

    def discover(self, uri):
        """Return a Deferred firing with (url, list of OpenID services)."""
        return self.discoverer.discover(uri)

    def buildRequest(self, endpoint, assoc, anonymous=False,
                     return_to_args=None):
        """Return a Deferred firing with an auth request for ``endpoint``."""
        return self.consumer.buildRequest(endpoint, assoc, anonymous,
                                          return_to_args)
141
142
class OpenIDDiscoverer(object):
    """
    Non-blocking OpenID discovery: Yadis/XRDS first, falling back to
    HTML <link rel="..."> discovery, with all HTTP done via getPage().
    """

    def discover(self, uri):
        """
        Discover OpenID services for ``uri``.

        Returns a Deferred firing with a ``(url, services)`` tuple.
        """
        def _to_services(result):
            yadis_url = result.normalized_uri
            try:
                services = extractServices(result.normalized_uri,
                                           result.response_text,
                                           OpenIDServiceEndpoint)
            except XRDSError:
                # Body is not a parseable Yadis XRDS document.
                services = []

            if not services:
                if result.isXRDS():
                    # Yadis content-type or Yadis header was followed but
                    # yielded no OpenID services: re-fetch the original
                    # document plainly (no Yadis Accept header).
                    return self._discoverNoYadis(uri)
                # Otherwise look for OpenID 1.0/1.1 <link rel="..."> tags.
                services = OpenIDServiceEndpoint.fromHTML(
                    yadis_url, result.response_text)

            return (yadis_url, getOPOrUserServices(services))

        return self._discoverYadis(uri).addCallback(_to_services)

    def _discoverYadis(self, uri):
        """
        Fetch ``uri`` with the Yadis Accept header, following a Yadis
        location to the XRDS document when one is advertised.

        Returns a Deferred firing with a populated discover.DiscoveryResult.
        """
        result = discover.DiscoveryResult(uri)

        def _follow_yadis_location(resp):
            result.normalized_uri = resp.final_url
            result.xrds_uri = discover.whereIsYadis(resp)
            if not (result.xrds_uri and result.usedYadisLocation()):
                return resp
            return getPage(result.xrds_uri)

        def _record_document(resp):
            # resp is the XRDS document if a Yadis location was followed,
            # otherwise the original response.
            result.content_type = resp.headers.get('content-type')
            result.response_text = resp.body
            return result

        deferred = getPage(uri, headers={"Accept": YADIS_ACCEPT_HEADER})
        deferred.addCallback(_follow_yadis_location)
        deferred.addCallback(_record_document)
        return deferred

    def _discoverNoYadis(self, uri):
        """
        Fetch ``uri`` plainly and parse OpenID services from its HTML.

        Returns a Deferred firing with ``(claimed_id, services)``.
        """
        def _services_from_html(resp):
            claimed_id = resp.final_url
            return (claimed_id,
                    OpenIDServiceEndpoint.fromHTML(claimed_id, resp.body))

        return getPage(uri).addCallback(_services_from_html)
224
225
226 class OpenIDConsumer(consumer.GenericConsumer):
227     def buildRequest(self, service_endpoint, assoc=None, anonymous=False,
228                      return_to_args=None):
229         """Create an AuthRequest object for the specified
230         service_endpoint. This method will create an association if
231         necessary."""
232         if self.store is None:
233             d = defer.succeed(None)
234         elif assoc:
235             d = defer.succeed(assoc)
236         else:
237             d = self.getAssociation(service_endpoint)
238            
239         def _build_request(assoc):
240             request = consumer.AuthRequest(service_endpoint, assoc)
241             request.return_to_args[self.openid1_nonce_query_arg_name] = mkNonce()
242             try:
243                 request.setAnonymous(anonymous)
244             except ValueError, why:
245                 raise consumer.ProtocolError(str(why))
246
247             if return_to_args:
248                 request.return_to_args.update(return_to_args)
249
250             return request
251         d.addCallback(_build_request)
252         return d
253    
254     def verifyResponse(self, message, endpoint, return_to=None):
255         return defer.succeed(consumer.GenericConsumer.complete(self, message, endpoint, return_to))
256
257     # attempt to get a shared secret, return it
258     def getAssociation(self, endpoint):
259         if self.store:
260             assoc = self.store.getAssociation(endpoint.server_url)
261         else:
262             assoc = None
263         if assoc is not None and assoc.expiresIn > 0:
264             return defer.succeed(assoc)
265        
266         d = self._negotiateAssociation(endpoint)
267         def _cacheResult(r):
268             self.store.storeAssociation(endpoint.server_url, r)
269             return r
270         if self.store:
271             d.addCallback(_cacheResult)
272         return d
273
274
275     # supporting for association
276     def _negotiateAssociation(self, endpoint):
277         assoc_type, session_type = self.negotiator.getAllowedType()
278
279         d = self._requestAssociation(endpoint, assoc_type, session_type)
280
281         def _handle_ServerError(f):
282             print "Server derror!", f
283             f.trap(consumer.ServerError)
284             why = f.exc_value
285
286             # Any error message whose code is not 'unsupported-type'
287             # should be considered a total failure.
288             if why.error_code != 'unsupported-type' or \
289                    why.message.isOpenID1():
290                 oidutil.log(
291                     'Server error when requesting an association from %r: %s'
292                     % (endpoint.server_url, why.error_text))
293                 return None
294
295             # The server didn't like the association/session type
296             # that we sent, and it sent us back a message that
297             # might tell us how to handle it.
298             oidutil.log(
299                 'Unsupported association type %s: %s' % (assoc_type,
300                                                          why.error_text,))
301
302             # Extract the session_type and assoc_type from the
303             # error message
304             assoc_type = why.message.getArg(OPENID_NS, 'assoc_type')
305             session_type = why.message.getArg(OPENID_NS, 'session_type')
306
307             if assoc_type is None or session_type is None:
308                 oidutil.log('Server responded with unsupported association '
309                             'session but did not supply a fallback.')
310                 return None
311             elif not self.negotiator.isAllowed(assoc_type, session_type):
312                 fmt = ('Server sent unsupported session/association type: '
313                        'session_type=%s, assoc_type=%s')
314                 oidutil.log(fmt % (session_type, assoc_type))
315                 return None
316             else:
317                 # Attempt to create an association from the assoc_type
318                 # and session_type that the server told us it
319                 # supported.
320                 d2 = self._requestAssociation(
321                         endpoint, assoc_type, session_type)
322                 return d2
323                 #except ServerError, why:
324                 #    # Do not keep trying, since it rejected the
325                 #    # association type that it told us to use.
326                 #    oidutil.log('Server %s refused its suggested association '
327                 #                'type: session_type=%s, assoc_type=%s'
328                 #                % (endpoint.server_url, session_type,
329                 #                   assoc_type))
330                 #    return None
331                 #else:
332                 #    return assoc
333
334         d.addErrback(_handle_ServerError)
335        
336         d.addErrback(lambda f: f.printTraceback())
337         return d
338
339     def _requestAssociation(self, endpoint, assoc_type, session_type):
340         assoc_session, args = self._createAssociateRequest(
341             endpoint, assoc_type, session_type)
342
343         d = self._makeKVPost(args, endpoint.server_url)
344
345         d.addCallback(self._extractAssociation, assoc_session)
346        
347         return d
348         #try:
349         #    assoc = self._extractAssociation(response, assoc_session)
350         #except KeyError, why:
351         #    oidutil.log('Missing required parameter in response from %s: %s'
352         #                % (endpoint.server_url, why[0]))
353         #    return None
354         #except ProtocolError, why:
355         #    oidutil.log('Protocol error parsing response from %s: %s' % (
356         #        endpoint.server_url, why[0]))
357         #    return None
358         #else:
359         #    return assoc
360    
361     def _createAssociateRequest(self, endpoint, assoc_type, session_type):
362         session_type_class = self.session_types[session_type]
363         assoc_session = session_type_class()
364
365         args = {
366             'mode': 'associate',
367             'assoc_type': assoc_type,
368             }
369
370         if not endpoint.compatibilityMode():
371             args['ns'] = OPENID2_NS
372
373         # Leave out the session type if we're in compatibility mode
374         # *and* it's no-encryption.
375         if (not endpoint.compatibilityMode() or
376             assoc_session.session_type != 'no-encryption'):
377             args['session_type'] = assoc_session.session_type
378
379         args.update(assoc_session.getRequest())
380         message = Message.fromOpenIDArgs(args)
381         return assoc_session, message
382
383     def _makeKVPost(self, args, endpoint_url):
384      
385         d = getPage(endpoint_url, method="POST", postdata=args.toURLEncoded())
386            
387         def _handleResponse(resp):
388             response_message = Message.fromKVForm(resp.body)
389    
390             if resp.status == 400:
391                 raise ServerError.fromMessage(response_message)
392             elif resp.status != 200:
393                 fmt = 'bad status code from server %s: %s'
394                 error_message = fmt % (endpoint_url, resp.status)
395                 raise Exception(error_message)
396
397             return response_message
398         d.addCallback(_handleResponse)
399         return d
400
401     def _extractAssociation(self, assoc_response, assoc_session):
402         # Extract the common fields from the response, raising an
403         # exception if they are not found
404         assoc_type = assoc_response.getArg(
405             OPENID_NS, 'assoc_type', no_default)
406         assoc_handle = assoc_response.getArg(
407             OPENID_NS, 'assoc_handle', no_default)
408
409         # expires_in is a base-10 string. The Python parsing will
410         # accept literals that have whitespace around them and will
411         # accept negative values. Neither of these are really in-spec,
412         # but we think it's OK to accept them.
413         expires_in_str = assoc_response.getArg(
414             OPENID_NS, 'expires_in', no_default)
415         try:
416             expires_in = int(expires_in_str)
417         except ValueError, why:
418             raise consumer.ProtocolError('Invalid expires_in field: %s' % (why[0],))
419
420         # OpenID 1 has funny association session behaviour.
421         if assoc_response.isOpenID1():
422             session_type = self._getOpenID1SessionType(assoc_response)
423         else:
424             session_type = assoc_response.getArg(
425                 OPENID2_NS, 'session_type', no_default)
426
427         # Session type mismatch
428         if assoc_session.session_type != session_type:
429             if (assoc_response.isOpenID1() and
430                 session_type == 'no-encryption'):
431                 # In OpenID 1, any association request can result in a
432                 # 'no-encryption' association response. Setting
433                 # assoc_session to a new no-encryption session should
434                 # make the rest of this function work properly for
435                 # that case.
436                 assoc_session = consumer.PlainTextConsumerSession()
437             else:
438                 # Any other mismatch, regardless of protocol version
439                 # results in the failure of the association session
440                 # altogether.
441                 fmt = 'Session type mismatch. Expected %r, got %r'
442                 message = fmt % (assoc_session.session_type, session_type)
443                 raise consumer.ProtocolError(message)
444
445         # Make sure assoc_type is valid for session_type
446         if assoc_type not in assoc_session.allowed_assoc_types:
447             fmt = 'Unsupported assoc_type for session %s returned: %s'
448             raise consumer.ProtocolError(fmt % (assoc_session.session_type, assoc_type))
449
450         # Delegate to the association session to extract the secret
451         # from the response, however is appropriate for that session
452         # type.
453         try:
454             secret = assoc_session.extractSecret(assoc_response)
455         except ValueError, why:
456             fmt = 'Malformed response for %s session: %s'
457             raise consumer.ProtocolError(fmt % (assoc_session.session_type, why[0]))
458
459         return Association.fromExpiresIn(
460             expires_in, assoc_handle, secret, assoc_type)
461
462     def _getOpenID1SessionType(self, assoc_response):
463         """Given an association response message, extract the OpenID
464         1.X session type.
465
466         This function mostly takes care of the 'no-encryption' default
467         behavior in OpenID 1.
468
469         If the association type is plain-text, this function will
470         return 'no-encryption'
471
472         @returns: The association type for this message
473         @rtype: str
474
475         @raises KeyError: when the session_type field is absent.
476         """
477         # If it's an OpenID 1 message, allow session_type to default
478         # to None (which signifies "no-encryption")
479         session_type = assoc_response.getArg(OPENID1_NS, 'session_type')
480
481         # Handle the differences between no-encryption association
482         # respones in OpenID 1 and 2:
483
484         # no-encryption is not really a valid session type for
485         # OpenID 1, but we'll accept it anyway, while issuing a
486         # warning.
487         if session_type == 'no-encryption':
488             oidutil.log('WARNING: OpenID server sent "no-encryption"'
489                         'for OpenID 1.X')
490
491         # Missing or empty session type is the way to flag a
492         # 'no-encryption' response. Change the session type to
493         # 'no-encryption' so that it can be handled in the same
494         # way as OpenID 2 'no-encryption' respones.
495         elif session_type == '' or session_type is None:
496             session_type = 'no-encryption'
497
498         return session_type
499
500
501 class OpenIDConsumerOptions(usage.Options):
502     optParameters = [
503             ['openid_realm', None, 'http://example.com'],
504             #['spread_poll_interval', None, '0.05'],
505             #['spread_host', None, 'localhost'],
506             #['spread_port', None, str(spread.DEFAULT_SPREAD_PORT)],
507             #['spread_private', None, 'dummy'],
508             #['spread_membership', None, '1'],