diff --git a/LICENSE b/LICENSE index 23cb790..d511905 100644 --- a/LICENSE +++ b/LICENSE @@ -1,12 +1,12 @@ - GNU GENERAL PUBLIC LICENSE - Version 2, June 1991 + GNU GENERAL PUBLIC LICENSE + Version 2, June 1991 - Copyright (C) 1989, 1991 Free Software Foundation, Inc., + Copyright (C) 1989, 1991 Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. - Preamble + Preamble The licenses for most software are designed to take away your freedom to share and change it. By contrast, the GNU General Public @@ -56,7 +56,7 @@ patent must be licensed for everyone's free use or not licensed at all. The precise terms and conditions for copying, distribution and modification follow. - GNU GENERAL PUBLIC LICENSE + GNU GENERAL PUBLIC LICENSE TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 0. This License applies to any program or other work which contains @@ -255,7 +255,7 @@ make exceptions for this. Our decision will be guided by the two goals of preserving the free status of all derivatives of our free software and of promoting the sharing and reuse of software generally. - NO WARRANTY + NO WARRANTY 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN @@ -277,9 +277,9 @@ YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. - END OF TERMS AND CONDITIONS + END OF TERMS AND CONDITIONS - How to Apply These Terms to Your New Programs + How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it @@ -290,8 +290,8 @@ to attach them to the start of each source file to most effectively convey the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. - {description} - Copyright (C) {year} {fullname} + + Copyright (C) This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -329,7 +329,7 @@ necessary. Here is a sample; alter the names: Yoyodyne, Inc., hereby disclaims all copyright interest in the program `Gnomovision' (which makes passes at compilers) written by James Hacker. - {signature of Ty Coon}, 1 April 1989 + , 1 April 1989 Ty Coon, President of Vice This General Public License does not permit incorporating your program into diff --git a/app.py b/app.py new file mode 100755 index 0000000..a31a9af --- /dev/null +++ b/app.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# CirruxCache provides dynamic HTTP caching on AppEngine (CDN like) +# Copyright (C) 2009 Samuel Alba +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 2 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +# + +"""CirruxCache provides dynamic HTTP caching on AppEngine (CDN like) + +http://cirrux.org/cache/ +http://code.google.com/p/cirruxcache/ +""" + +__version__ = '0.4.2' +__author__ = [ + 'Samuel Alba ' +] +__license__ = 'GNU Public License version 2' + +import sys, os +root = os.path.dirname(os.environ['PATH_TRANSLATED']) +sys.path.append(os.path.join(root, 'lib')) +sys.path.append(os.path.join(root, 'contrib')) + +import logging +import web + +import config +from services.cron import Cron +from services.admin import Admin +from services.store import Store +from services.debug import Debug + + +class Root(object): + + def GET(self, request): + web.header('Content-Type', 'text/plain') + content = 'CirruxCache (%s) / http://code.google.com/p/cirruxcache/\n' % __version__ + if request: + raise web.HTTPError(status='404 Not Found', data=content) + return content + +class VhostMapper(object): + + def __iter__(self): + base = ( + '/_admin/(.*)', 'Admin', + '/_store/(.*)', 'Store', + '/_cron/(.*)', 'Cron' + ) + url = () + urls = config.urls + if 'HTTP_HOST' in web.ctx.environ: + vhost = web.ctx.environ['HTTP_HOST'] + if vhost in urls: + url = urls[vhost] + elif 'default' in urls: + url = urls['default'] + return iter(base + url + ('/(.*)', 'Root')) + +if __name__ == '__main__': +# logging.getLogger().setLevel(logging.DEBUG) + app = web.application(VhostMapper(), globals()) + app.cgirun() diff --git a/app.yaml b/app.yaml new file mode 100644 index 0000000..c30ba20 --- /dev/null +++ b/app.yaml @@ -0,0 +1,38 @@ +# CirruxCache provides dynamic HTTP caching on AppEngine (CDN like) +# Copyright (C) 2009 Samuel Alba +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 2 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +# + +application: appname +version: 1 +runtime: python +api_version: 1 + +handlers: +# Uncomment theses 3 lignes below to serve a static crossdomain.xml file +#- url: /crossdomain\.xml +# static_files: static/crossdomain.xml +# upload: static/crossdomain.xml + +- url: /_static/((admin|jquery\.form|wtooltip)\.(js|css)) + static_files: static/\1 + upload: static/(admin|jquery\.form|wtooltip)\.(js|css) + +- url: /.* + script: app.py + +derived_file_type: +- python_precompiled diff --git a/config.py b/config.py new file mode 100644 index 0000000..e30b86b --- /dev/null +++ b/config.py @@ -0,0 +1,47 @@ +# CirruxCache provides dynamic HTTP caching on AppEngine (CDN like) +# Copyright (C) 2009 Samuel Alba +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 2 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +# + +"""CirruxCache configuration file +""" + +from lib import cache, redirect, forward, image +urls = {} + + +# EDIT BELOW + +urls['default'] = ( + #'(/debug/.*)', 'Debug', + '(/data/.*)', 'config.Static', + '/www(/.*)', 'config.Www' + ) + +# POP definition +# You can define and configure your Point Of Presence + +class Static(cache.Service): + origin = 'http://static.mydomain.tld' + maxTTL = 2592000 # 1 month + ignoreQueryString = True + +class Www(cache.Service): + origin = 'http://www.mydomain.tld' + allowFlushFrom = ['127.0.0.1'] + forceTTL = 3600 # 1 hour + ignoreQueryString = True + forwardPost = False diff --git a/contrib/web/__init__.py b/contrib/web/__init__.py new file mode 100644 index 0000000..337d9bf --- /dev/null +++ b/contrib/web/__init__.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +"""web.py: makes web apps (http://webpy.org)""" + +from __future__ import generators + +__version__ = "0.34" +__author__ = [ + "Aaron Swartz ", + "Anand Chitipothu " +] +__license__ = "public domain" +__contributors__ = "see http://webpy.org/changes" + +import utils, db, net, wsgi, http, webapi, httpserver, debugerror +import template, form + +import session + +from utils import * +from db import * +from net import * +from wsgi import * +from http import * +from webapi import * +from httpserver import * +from debugerror import * +from application import * +from browser import * +import test +try: + import webopenid as openid +except ImportError: + pass # requires openid module + diff --git a/contrib/web/application.py b/contrib/web/application.py new file mode 100755 index 0000000..4104fe7 --- /dev/null +++ b/contrib/web/application.py @@ -0,0 +1,670 @@ +#!/usr/bin/python +""" +Web application +(from web.py) +""" +import webapi as web +import webapi, wsgi, utils +import debugerror +from utils import lstrips, safeunicode +import sys + +import urllib +import traceback +import itertools +import os +import re +import types +from exceptions import SystemExit + +try: + import wsgiref.handlers +except ImportError: + pass # don't break people with old Pythons + +__all__ = [ + "application", "auto_application", + "subdir_application", "subdomain_application", + "loadhook", "unloadhook", + "autodelegate" +] + +class application: + """ + Application to delegate requests based on path. + + >>> urls = ("/hello", "hello") + >>> app = application(urls, globals()) + >>> class hello: + ... def GET(self): return "hello" + >>> + >>> app.request("/hello").data + 'hello' + """ + def __init__(self, mapping=(), fvars={}, autoreload=None): + if autoreload is None: + autoreload = web.config.get('debug', False) + self.mapping = mapping + self.fvars = fvars + self.processors = [] + + self.add_processor(loadhook(self._load)) + self.add_processor(unloadhook(self._unload)) + + if autoreload: + def main_module_name(): + mod = sys.modules['__main__'] + file = getattr(mod, '__file__', None) # make sure this works even from python interpreter + return file and os.path.splitext(os.path.basename(file))[0] + + def modname(fvars): + """find name of the module name from fvars.""" + file, name = fvars.get('__file__'), fvars.get('__name__') + if file is None or name is None: + return None + + if name == '__main__': + # Since the __main__ module can't be reloaded, the module has + # to be imported using its file name. + name = main_module_name() + return name + + mapping_name = utils.dictfind(fvars, mapping) + module_name = modname(fvars) + + def reload_mapping(): + """loadhook to reload mapping and fvars.""" + mod = __import__(module_name) + mapping = getattr(mod, mapping_name, None) + if mapping: + self.fvars = mod.__dict__ + self.mapping = mapping + + self.add_processor(loadhook(Reloader())) + if mapping_name and module_name: + self.add_processor(loadhook(reload_mapping)) + + # load __main__ module usings its filename, so that it can be reloaded. + if main_module_name() and '__main__' in sys.argv: + try: + __import__(main_module_name()) + except ImportError: + pass + + def _load(self): + web.ctx.app_stack.append(self) + + def _unload(self): + web.ctx.app_stack = web.ctx.app_stack[:-1] + + if web.ctx.app_stack: + # this is a sub-application, revert ctx to earlier state. + oldctx = web.ctx.get('_oldctx') + if oldctx: + web.ctx.home = oldctx.home + web.ctx.homepath = oldctx.homepath + web.ctx.path = oldctx.path + web.ctx.fullpath = oldctx.fullpath + + def _cleanup(self): + #@@@ + # Since the CherryPy Webserver uses thread pool, the thread-local state is never cleared. + # This interferes with the other requests. + # clearing the thread-local storage to avoid that. + # see utils.ThreadedDict for details + import threading + t = threading.currentThread() + if hasattr(t, '_d'): + del t._d + + def add_mapping(self, pattern, classname): + self.mapping += (pattern, classname) + + def add_processor(self, processor): + """ + Adds a processor to the application. + + >>> urls = ("/(.*)", "echo") + >>> app = application(urls, globals()) + >>> class echo: + ... def GET(self, name): return name + ... + >>> + >>> def hello(handler): return "hello, " + handler() + ... + >>> app.add_processor(hello) + >>> app.request("/web.py").data + 'hello, web.py' + """ + self.processors.append(processor) + + def request(self, localpart='/', method='GET', data=None, + host="0.0.0.0:8080", headers=None, https=False, **kw): + """Makes request to this application for the specified path and method. + Response will be a storage object with data, status and headers. + + >>> urls = ("/hello", "hello") + >>> app = application(urls, globals()) + >>> class hello: + ... def GET(self): + ... web.header('Content-Type', 'text/plain') + ... return "hello" + ... + >>> response = app.request("/hello") + >>> response.data + 'hello' + >>> response.status + '200 OK' + >>> response.headers['Content-Type'] + 'text/plain' + + To use https, use https=True. + + >>> urls = ("/redirect", "redirect") + >>> app = application(urls, globals()) + >>> class redirect: + ... def GET(self): raise web.seeother("/foo") + ... + >>> response = app.request("/redirect") + >>> response.headers['Location'] + 'http://0.0.0.0:8080/foo' + >>> response = app.request("/redirect", https=True) + >>> response.headers['Location'] + 'https://0.0.0.0:8080/foo' + + The headers argument specifies HTTP headers as a mapping object + such as a dict. + + >>> urls = ('/ua', 'uaprinter') + >>> class uaprinter: + ... def GET(self): + ... return 'your user-agent is ' + web.ctx.env['HTTP_USER_AGENT'] + ... + >>> app = application(urls, globals()) + >>> app.request('/ua', headers = { + ... 'User-Agent': 'a small jumping bean/1.0 (compatible)' + ... }).data + 'your user-agent is a small jumping bean/1.0 (compatible)' + + """ + path, maybe_query = urllib.splitquery(localpart) + query = maybe_query or "" + + if 'env' in kw: + env = kw['env'] + else: + env = {} + env = dict(env, HTTP_HOST=host, REQUEST_METHOD=method, PATH_INFO=path, QUERY_STRING=query, HTTPS=str(https)) + headers = headers or {} + + for k, v in headers.items(): + env['HTTP_' + k.upper().replace('-', '_')] = v + + if 'HTTP_CONTENT_LENGTH' in env: + env['CONTENT_LENGTH'] = env.pop('HTTP_CONTENT_LENGTH') + + if 'HTTP_CONTENT_TYPE' in env: + env['CONTENT_TYPE'] = env.pop('HTTP_CONTENT_TYPE') + + if method in ["POST", "PUT"]: + data = data or '' + import StringIO + if isinstance(data, dict): + q = urllib.urlencode(data) + else: + q = data + env['wsgi.input'] = StringIO.StringIO(q) + if not env.get('CONTENT_TYPE', '').lower().startswith('multipart/') and 'CONTENT_LENGTH' not in env: + env['CONTENT_LENGTH'] = len(q) + response = web.storage() + def start_response(status, headers): + response.status = status + response.headers = dict(headers) + response.header_items = headers + response.data = "".join(self.wsgifunc()(env, start_response)) + return response + + def browser(self): + import browser + return browser.AppBrowser(self) + + def handle(self): + fn, args = self._match(self.mapping, web.ctx.path) + return self._delegate(fn, self.fvars, args) + + def handle_with_processors(self): + def process(processors): + try: + if processors: + p, processors = processors[0], processors[1:] + return p(lambda: process(processors)) + else: + return self.handle() + except web.HTTPError: + raise + except (KeyboardInterrupt, SystemExit): + raise + except: + print >> web.debug, traceback.format_exc() + raise self.internalerror() + + # processors must be applied in the resvere order. (??) + return process(self.processors) + + def wsgifunc(self, *middleware): + """Returns a WSGI-compatible function for this application.""" + def peep(iterator): + """Peeps into an iterator by doing an iteration + and returns an equivalent iterator. + """ + # wsgi requires the headers first + # so we need to do an iteration + # and save the result for later + try: + firstchunk = iterator.next() + except StopIteration: + firstchunk = '' + + return itertools.chain([firstchunk], iterator) + + def is_generator(x): return x and hasattr(x, 'next') + + def wsgi(env, start_resp): + # clear threadlocal to avoid inteference of previous requests + self._cleanup() + + self.load(env) + try: + # allow uppercase methods only + if web.ctx.method.upper() != web.ctx.method: + raise web.nomethod() + + result = self.handle_with_processors() + if is_generator(result): + result = peep(result) + else: + result = [result] + except web.HTTPError, e: + result = [e.data] + + result = web.utf8(iter(result)) + + status, headers = web.ctx.status, web.ctx.headers + start_resp(status, headers) + + def cleanup(): + self._cleanup() + yield '' # force this function to be a generator + + return itertools.chain(result, cleanup()) + + for m in middleware: + wsgi = m(wsgi) + + return wsgi + + def run(self, *middleware): + """ + Starts handling requests. If called in a CGI or FastCGI context, it will follow + that protocol. If called from the command line, it will start an HTTP + server on the port named in the first command line argument, or, if there + is no argument, on port 8080. + + `middleware` is a list of WSGI middleware which is applied to the resulting WSGI + function. + """ + return wsgi.runwsgi(self.wsgifunc(*middleware)) + + def cgirun(self, *middleware): + """ + Return a CGI handler. This is mostly useful with Google App Engine. + There you can just do: + + main = app.cgirun() + """ + wsgiapp = self.wsgifunc(*middleware) + + try: + from google.appengine.ext.webapp.util import run_wsgi_app + return run_wsgi_app(wsgiapp) + except ImportError: + # we're not running from within Google App Engine + return wsgiref.handlers.CGIHandler().run(wsgiapp) + + def load(self, env): + """Initializes ctx using env.""" + ctx = web.ctx + ctx.clear() + ctx.status = '200 OK' + ctx.headers = [] + ctx.output = '' + ctx.environ = ctx.env = env + ctx.host = env.get('HTTP_HOST') + + if env.get('wsgi.url_scheme') in ['http', 'https']: + ctx.protocol = env['wsgi.url_scheme'] + elif env.get('HTTPS', '').lower() in ['on', 'true', '1']: + ctx.protocol = 'https' + else: + ctx.protocol = 'http' + ctx.homedomain = ctx.protocol + '://' + env.get('HTTP_HOST', '[unknown]') + ctx.homepath = os.environ.get('REAL_SCRIPT_NAME', env.get('SCRIPT_NAME', '')) + ctx.home = ctx.homedomain + ctx.homepath + #@@ home is changed when the request is handled to a sub-application. + #@@ but the real home is required for doing absolute redirects. + ctx.realhome = ctx.home + ctx.ip = env.get('REMOTE_ADDR') + ctx.method = env.get('REQUEST_METHOD') + ctx.path = env.get('PATH_INFO') + # http://trac.lighttpd.net/trac/ticket/406 requires: + if env.get('SERVER_SOFTWARE', '').startswith('lighttpd/'): + ctx.path = lstrips(env.get('REQUEST_URI').split('?')[0], ctx.homepath) + # Apache and CherryPy webservers unquote the url but lighttpd doesn't. + # unquote explicitly for lighttpd to make ctx.path uniform across all servers. + ctx.path = urllib.unquote(ctx.path) + + if env.get('QUERY_STRING'): + ctx.query = '?' + env.get('QUERY_STRING', '') + else: + ctx.query = '' + + ctx.fullpath = ctx.path + ctx.query + + for k, v in ctx.iteritems(): + if isinstance(v, str): + ctx[k] = safeunicode(v) + + # status must always be str + ctx.status = '200 OK' + + ctx.app_stack = [] + + def _delegate(self, f, fvars, args=[]): + def handle_class(cls): + meth = web.ctx.method + if meth == 'HEAD' and not hasattr(cls, meth): + meth = 'GET' + if not hasattr(cls, meth): + raise web.nomethod(cls) + tocall = getattr(cls(), meth) + return tocall(*args) + + def is_class(o): return isinstance(o, (types.ClassType, type)) + + if f is None: + raise web.notfound() + elif isinstance(f, application): + return f.handle_with_processors() + elif is_class(f): + return handle_class(f) + elif isinstance(f, basestring): + if f.startswith('redirect '): + url = f.split(' ', 1)[1] + if web.ctx.method == "GET": + x = web.ctx.env.get('QUERY_STRING', '') + if x: + url += '?' + x + raise web.redirect(url) + elif '.' in f: + x = f.split('.') + mod, cls = '.'.join(x[:-1]), x[-1] + mod = __import__(mod, globals(), locals(), [""]) + cls = getattr(mod, cls) + else: + cls = fvars[f] + return handle_class(cls) + elif hasattr(f, '__call__'): + return f() + else: + return web.notfound() + + def _match(self, mapping, value): + for pat, what in utils.group(mapping, 2): + if isinstance(what, application): + if value.startswith(pat): + f = lambda: self._delegate_sub_application(pat, what) + return f, None + else: + continue + elif isinstance(what, basestring): + what, result = utils.re_subm('^' + pat + '$', what, value) + else: + result = utils.re_compile('^' + pat + '$').match(value) + + if result: # it's a match + return what, [x for x in result.groups()] + return None, None + + def _delegate_sub_application(self, dir, app): + """Deletes request to sub application `app` rooted at the directory `dir`. + The home, homepath, path and fullpath values in web.ctx are updated to mimic request + to the subapp and are restored after it is handled. + + @@Any issues with when used with yield? + """ + web.ctx._oldctx = web.storage(web.ctx) + web.ctx.home += dir + web.ctx.homepath += dir + web.ctx.path = web.ctx.path[len(dir):] + web.ctx.fullpath = web.ctx.fullpath[len(dir):] + return app.handle_with_processors() + + def get_parent_app(self): + if self in web.ctx.app_stack: + index = web.ctx.app_stack.index(self) + if index > 0: + return web.ctx.app_stack[index-1] + + def notfound(self): + """Returns HTTPError with '404 not found' message""" + parent = self.get_parent_app() + if parent: + return parent.notfound() + else: + return web._NotFound() + + def internalerror(self): + """Returns HTTPError with '500 internal error' message""" + parent = self.get_parent_app() + if parent: + return parent.internalerror() + elif web.config.get('debug'): + import debugerror + return debugerror.debugerror() + else: + return web._InternalError() + +class auto_application(application): + """Application similar to `application` but urls are constructed + automatiacally using metaclass. + + >>> app = auto_application() + >>> class hello(app.page): + ... def GET(self): return "hello, world" + ... + >>> class foo(app.page): + ... path = '/foo/.*' + ... def GET(self): return "foo" + >>> app.request("/hello").data + 'hello, world' + >>> app.request('/foo/bar').data + 'foo' + """ + def __init__(self): + application.__init__(self) + + class metapage(type): + def __init__(klass, name, bases, attrs): + type.__init__(klass, name, bases, attrs) + path = attrs.get('path', '/' + name) + + # path can be specified as None to ignore that class + # typically required to create a abstract base class. + if path is not None: + self.add_mapping(path, klass) + + class page: + path = None + __metaclass__ = metapage + + self.page = page + +# The application class already has the required functionality of subdir_application +subdir_application = application + +class subdomain_application(application): + """ + Application to delegate requests based on the host. + + >>> urls = ("/hello", "hello") + >>> app = application(urls, globals()) + >>> class hello: + ... def GET(self): return "hello" + >>> + >>> mapping = (r"hello\.example\.com", app) + >>> app2 = subdomain_application(mapping) + >>> app2.request("/hello", host="hello.example.com").data + 'hello' + >>> response = app2.request("/hello", host="something.example.com") + >>> response.status + '404 Not Found' + >>> response.data + 'not found' + """ + def handle(self): + host = web.ctx.host.split(':')[0] #strip port + fn, args = self._match(self.mapping, host) + return self._delegate(fn, self.fvars, args) + + def _match(self, mapping, value): + for pat, what in utils.group(mapping, 2): + if isinstance(what, basestring): + what, result = utils.re_subm('^' + pat + '$', what, value) + else: + result = utils.re_compile('^' + pat + '$').match(value) + + if result: # it's a match + return what, [x for x in result.groups()] + return None, None + +def loadhook(h): + """ + Converts a load hook into an application processor. + + >>> app = auto_application() + >>> def f(): "something done before handling request" + ... + >>> app.add_processor(loadhook(f)) + """ + def processor(handler): + h() + return handler() + + return processor + +def unloadhook(h): + """ + Converts an unload hook into an application processor. + + >>> app = auto_application() + >>> def f(): "something done after handling request" + ... + >>> app.add_processor(unloadhook(f)) + """ + def processor(handler): + try: + result = handler() + is_generator = result and hasattr(result, 'next') + except: + # run the hook even when handler raises some exception + h() + raise + + if is_generator: + return wrap(result) + else: + h() + return result + + def wrap(result): + def next(): + try: + return result.next() + except: + # call the hook at the and of iterator + h() + raise + + result = iter(result) + while True: + yield next() + + return processor + +def autodelegate(prefix=''): + """ + Returns a method that takes one argument and calls the method named prefix+arg, + calling `notfound()` if there isn't one. Example: + + urls = ('/prefs/(.*)', 'prefs') + + class prefs: + GET = autodelegate('GET_') + def GET_password(self): pass + def GET_privacy(self): pass + + `GET_password` would get called for `/prefs/password` while `GET_privacy` for + `GET_privacy` gets called for `/prefs/privacy`. + + If a user visits `/prefs/password/change` then `GET_password(self, '/change')` + is called. + """ + def internal(self, arg): + if '/' in arg: + first, rest = arg.split('/', 1) + func = prefix + first + args = ['/' + rest] + else: + func = prefix + arg + args = [] + + if hasattr(self, func): + try: + return getattr(self, func)(*args) + except TypeError: + raise web.notfound() + else: + raise web.notfound() + return internal + +class Reloader: + """Checks to see if any loaded modules have changed on disk and, + if so, reloads them. + """ + def __init__(self): + self.mtimes = {} + + def __call__(self): + for mod in sys.modules.values(): + self.check(mod) + + def check(self, mod): + try: + mtime = os.stat(mod.__file__).st_mtime + except (AttributeError, OSError, IOError): + return + if mod.__file__.endswith('.pyc') and os.path.exists(mod.__file__[:-1]): + mtime = max(os.stat(mod.__file__[:-1]).st_mtime, mtime) + + if mod not in self.mtimes: + self.mtimes[mod] = mtime + elif self.mtimes[mod] < mtime: + try: + reload(mod) + self.mtimes[mod] = mtime + except ImportError: + pass + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/contrib/web/browser.py b/contrib/web/browser.py new file mode 100644 index 0000000..66d859e --- /dev/null +++ b/contrib/web/browser.py @@ -0,0 +1,236 @@ +"""Browser to test web applications. +(from web.py) +""" +from utils import re_compile +from net import htmlunquote + +import httplib, urllib, urllib2 +import copy +from StringIO import StringIO + +DEBUG = False + +__all__ = [ + "BrowserError", + "Browser", "AppBrowser", + "AppHandler" +] + +class BrowserError(Exception): + pass + +class Browser: + def __init__(self): + import cookielib + self.cookiejar = cookielib.CookieJar() + self._cookie_processor = urllib2.HTTPCookieProcessor(self.cookiejar) + self.form = None + + self.url = "http://0.0.0.0:8080/" + self.path = "/" + + self.status = None + self.data = None + self._response = None + self._forms = None + + def reset(self): + """Clears all cookies and history.""" + self.cookiejar.clear() + + def build_opener(self): + """Builds the opener using urllib2.build_opener. + Subclasses can override this function to prodive custom openers. + """ + return urllib2.build_opener() + + def do_request(self, req): + if DEBUG: + print 'requesting', req.get_method(), req.get_full_url() + opener = self.build_opener() + opener.add_handler(self._cookie_processor) + try: + self._response = opener.open(req) + except urllib2.HTTPError, e: + self._response = e + + self.url = self._response.geturl() + self.path = urllib2.Request(self.url).get_selector() + self.data = self._response.read() + self.status = self._response.code + self._forms = None + self.form = None + return self.get_response() + + def open(self, url, data=None, headers={}): + """Opens the specified url.""" + url = urllib.basejoin(self.url, url) + req = urllib2.Request(url, data, headers) + return self.do_request(req) + + def show(self): + """Opens the current page in real web browser.""" + f = open('page.html', 'w') + f.write(self.data) + f.close() + + import webbrowser, os + url = 'file://' + os.path.abspath('page.html') + webbrowser.open(url) + + def get_response(self): + """Returns a copy of the current response.""" + return urllib.addinfourl(StringIO(self.data), self._response.info(), self._response.geturl()) + + def get_soup(self): + """Returns beautiful soup of the current document.""" + import BeautifulSoup + return BeautifulSoup.BeautifulSoup(self.data) + + def get_text(self, e=None): + """Returns content of e or the current document as plain text.""" + e = e or self.get_soup() + return ''.join([htmlunquote(c) for c in e.recursiveChildGenerator() if isinstance(c, unicode)]) + + def _get_links(self): + soup = self.get_soup() + return [a for a in soup.findAll(name='a')] + + def get_links(self, text=None, text_regex=None, url=None, url_regex=None, predicate=None): + """Returns all links in the document.""" + return self._filter_links(self._get_links(), + text=text, text_regex=text_regex, url=url, url_regex=url_regex, predicate=predicate) + + def follow_link(self, link=None, text=None, text_regex=None, url=None, url_regex=None, predicate=None): + if link is None: + links = self._filter_links(self.get_links(), + text=text, text_regex=text_regex, url=url, url_regex=url_regex, predicate=predicate) + link = links and links[0] + + if link: + return self.open(link['href']) + else: + raise BrowserError("No link found") + + def find_link(self, text=None, text_regex=None, url=None, url_regex=None, predicate=None): + links = self._filter_links(self.get_links(), + text=text, text_regex=text_regex, url=url, url_regex=url_regex, predicate=predicate) + return links and links[0] or None + + def _filter_links(self, links, + text=None, text_regex=None, + url=None, url_regex=None, + predicate=None): + predicates = [] + if text is not None: + predicates.append(lambda link: link.string == text) + if text_regex is not None: + predicates.append(lambda link: re_compile(text_regex).search(link.string or '')) + if url is not None: + predicates.append(lambda link: link.get('href') == url) + if url_regex is not None: + predicates.append(lambda link: re_compile(url_regex).search(link.get('href', ''))) + if predicate: + predicate.append(predicate) + + def f(link): + for p in predicates: + if not p(link): + return False + return True + + return [link for link in links if f(link)] + + def get_forms(self): + """Returns all forms in the current document. + The returned form objects implement the ClientForm.HTMLForm interface. + """ + if self._forms is None: + import ClientForm + self._forms = ClientForm.ParseResponse(self.get_response(), backwards_compat=False) + return self._forms + + def select_form(self, name=None, predicate=None, index=0): + """Selects the specified form.""" + forms = self.get_forms() + + if name is not None: + forms = [f for f in forms if f.name == name] + if predicate: + forms = [f for f in forms if predicate(f)] + + if forms: + self.form = forms[index] + return self.form + else: + raise BrowserError("No form selected.") + + def submit(self, **kw): + """submits the currently selected form.""" + if self.form is None: + raise BrowserError("No form selected.") + req = self.form.click(**kw) + return self.do_request(req) + + def __getitem__(self, key): + return self.form[key] + + def __setitem__(self, key, value): + self.form[key] = value + +class AppBrowser(Browser): + """Browser interface to test web.py apps. + + b = AppBrowser(app) + b.open('/') + b.follow_link(text='Login') + + b.select_form(name='login') + b['username'] = 'joe' + b['password'] = 'secret' + b.submit() + + assert b.path == '/' + assert 'Welcome joe' in b.get_text() + """ + def __init__(self, app): + Browser.__init__(self) + self.app = app + + def build_opener(self): + return urllib2.build_opener(AppHandler(self.app)) + +class AppHandler(urllib2.HTTPHandler): + """urllib2 handler to handle requests using web.py application.""" + handler_order = 100 + + def __init__(self, app): + self.app = app + + def http_open(self, req): + result = self.app.request( + localpart=req.get_selector(), + method=req.get_method(), + host=req.get_host(), + data=req.get_data(), + headers=dict(req.header_items()), + https=req.get_type() == "https" + ) + return self._make_response(result, req.get_full_url()) + + def https_open(self, req): + return self.http_open(req) + + try: + https_request = urllib2.HTTPHandler.do_request_ + except AttributeError: + # for python 2.3 + pass + + def _make_response(self, result, url): + data = "\r\n".join(["%s: %s" % (k, v) for k, v in result.header_items]) + headers = httplib.HTTPMessage(StringIO(data)) + response = urllib.addinfourl(StringIO(result.data), headers, url) + code, msg = result.status.split(None, 1) + response.code, response.msg = int(code), msg + return response diff --git a/contrib/web/contrib/__init__.py b/contrib/web/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/contrib/web/contrib/template.py b/contrib/web/contrib/template.py new file mode 100644 index 0000000..7495d39 --- /dev/null +++ b/contrib/web/contrib/template.py @@ -0,0 +1,131 @@ +""" +Interface to various templating engines. +""" +import os.path + +__all__ = [ + "render_cheetah", "render_genshi", "render_mako", + "cache", +] + +class render_cheetah: + """Rendering interface to Cheetah Templates. + + Example: + + render = render_cheetah('templates') + render.hello(name="cheetah") + """ + def __init__(self, path): + # give error if Chetah is not installed + from Cheetah.Template import Template + self.path = path + + def __getattr__(self, name): + from Cheetah.Template import Template + path = os.path.join(self.path, name + ".html") + + def template(**kw): + t = Template(file=path, searchList=[kw]) + return t.respond() + + return template + +class render_genshi: + """Rendering interface genshi templates. + Example: + + for xml/html templates. + + render = render_genshi(['templates/']) + render.hello(name='genshi') + + For text templates: + + render = render_genshi(['templates/'], type='text') + render.hello(name='genshi') + """ + + def __init__(self, *a, **kwargs): + from genshi.template import TemplateLoader + + self._type = kwargs.pop('type', None) + self._loader = TemplateLoader(*a, **kwargs) + + def __getattr__(self, name): + # Assuming all templates are html + path = name + ".html" + + if self._type == "text": + from genshi.template import TextTemplate + cls = TextTemplate + type = "text" + else: + cls = None + type = None + + t = self._loader.load(path, cls=cls) + def template(**kw): + stream = t.generate(**kw) + if type: + return stream.render(type) + else: + return stream.render() + return template + +class render_jinja: + """Rendering interface to Jinja2 Templates + + Example: + + render= render_jinja('templates') + render.hello(name='jinja2') + """ + def __init__(self, *a, **kwargs): + extensions = kwargs.pop('extensions', []) + globals = kwargs.pop('globals', {}) + + from jinja2 import Environment,FileSystemLoader + self._lookup = Environment(loader=FileSystemLoader(*a, **kwargs), extensions=extensions) + self._lookup.globals.update(globals) + + def __getattr__(self, name): + # Assuming all templates end with .html + path = name + '.html' + t = self._lookup.get_template(path) + return t.render + +class render_mako: + """Rendering interface to Mako Templates. + + Example: + + render = render_mako(directories=['templates']) + render.hello(name="mako") + """ + def __init__(self, *a, **kwargs): + from mako.lookup import TemplateLookup + self._lookup = TemplateLookup(*a, **kwargs) + + def __getattr__(self, name): + # Assuming all templates are html + path = name + ".html" + t = self._lookup.get_template(path) + return t.render + +class cache: + """Cache for any rendering interface. + + Example: + + render = cache(render_cheetah("templates/")) + render.hello(name='cache') + """ + def __init__(self, render): + self._render = render + self._cache = {} + + def __getattr__(self, name): + if name not in self._cache: + self._cache[name] = getattr(self._render, name) + return self._cache[name] diff --git a/contrib/web/db.py b/contrib/web/db.py new file mode 100644 index 0000000..91ea8d6 --- /dev/null +++ b/contrib/web/db.py @@ -0,0 +1,1172 @@ +""" +Database API +(part of web.py) +""" + +__all__ = [ + "UnknownParamstyle", "UnknownDB", "TransactionError", + "sqllist", "sqlors", "reparam", "sqlquote", + "SQLQuery", "SQLParam", "sqlparam", + "SQLLiteral", "sqlliteral", + "database", 'DB', +] + +import time +try: + import datetime +except ImportError: + datetime = None + +try: set +except NameError: + from sets import Set as set + +from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode + +try: + # db module can work independent of web.py + from webapi import debug, config +except: + import sys + debug = sys.stderr + config = storage() + +class UnknownDB(Exception): + """raised for unsupported dbms""" + pass + +class _ItplError(ValueError): + def __init__(self, text, pos): + ValueError.__init__(self) + self.text = text + self.pos = pos + def __str__(self): + return "unfinished expression in %s at char %d" % ( + repr(self.text), self.pos) + +class TransactionError(Exception): pass + +class UnknownParamstyle(Exception): + """ + raised for unsupported db paramstyles + + (currently supported: qmark, numeric, format, pyformat) + """ + pass + +class SQLParam: + """ + Parameter in SQLQuery. + + >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")]) + >>> q + + >>> q.query() + 'SELECT * FROM test WHERE name=%s' + >>> q.values() + ['joe'] + """ + def __init__(self, value): + self.value = value + + def get_marker(self, paramstyle='pyformat'): + if paramstyle == 'qmark': + return '?' + elif paramstyle == 'numeric': + return ':1' + elif paramstyle is None or paramstyle in ['format', 'pyformat']: + return '%s' + raise UnknownParamstyle, paramstyle + + def sqlquery(self): + return SQLQuery([self]) + + def __add__(self, other): + return self.sqlquery() + other + + def __radd__(self, other): + return other + self.sqlquery() + + def __str__(self): + return str(self.value) + + def __repr__(self): + return '' % repr(self.value) + +sqlparam = SQLParam + +class SQLQuery: + """ + You can pass this sort of thing as a clause in any db function. + Otherwise, you can pass a dictionary to the keyword argument `vars` + and the function will call reparam for you. + + Internally, consists of `items`, which is a list of strings and + SQLParams, which get concatenated to produce the actual query. + """ + # tested in sqlquote's docstring + def __init__(self, items=[]): + r"""Creates a new SQLQuery. + + >>> SQLQuery("x") + + >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)]) + >>> q + + >>> q.query(), q.values() + ('SELECT * FROM test WHERE x=%s', [1]) + >>> SQLQuery(SQLParam(1)) + + """ + if isinstance(items, list): + self.items = items + elif isinstance(items, SQLParam): + self.items = [items] + elif isinstance(items, SQLQuery): + self.items = list(items.items) + else: + self.items = [items] + + # Take care of SQLLiterals + for i, item in enumerate(self.items): + if isinstance(item, SQLParam) and isinstance(item.value, SQLLiteral): + self.items[i] = item.value.v + + def __add__(self, other): + if isinstance(other, basestring): + items = [other] + elif isinstance(other, SQLQuery): + items = other.items + else: + return NotImplemented + return SQLQuery(self.items + items) + + def __radd__(self, other): + if isinstance(other, basestring): + items = [other] + else: + return NotImplemented + + return SQLQuery(items + self.items) + + def __iadd__(self, other): + if isinstance(other, basestring): + items = [other] + elif isinstance(other, SQLQuery): + items = other.items + else: + return NotImplemented + self.items.extend(items) + return self + + def __len__(self): + return len(self.query()) + + def query(self, paramstyle=None): + """ + Returns the query part of the sql query. + >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')]) + >>> q.query() + 'SELECT * FROM test WHERE name=%s' + >>> q.query(paramstyle='qmark') + 'SELECT * FROM test WHERE name=?' + """ + s = [] + for x in self.items: + if isinstance(x, SQLParam): + x = x.get_marker(paramstyle) + s.append(safestr(x)) + else: + x = safestr(x) + # automatically escape % characters in the query + # For backward compatability, ignore escaping when the query looks already escaped + if paramstyle in ['format', 'pyformat']: + if '%' in x and '%%' not in x: + x = x.replace('%', '%%') + s.append(x) + return "".join(s) + + def values(self): + """ + Returns the values of the parameters used in the sql query. + >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')]) + >>> q.values() + ['joe'] + """ + return [i.value for i in self.items if isinstance(i, SQLParam)] + + def join(items, sep=' '): + """ + Joins multiple queries. + + >>> SQLQuery.join(['a', 'b'], ', ') + + """ + if len(items) == 0: + return SQLQuery("") + + q = SQLQuery(items[0]) + for item in items[1:]: + q += sep + q += item + return q + + join = staticmethod(join) + + def _str(self): + try: + return self.query() % tuple([sqlify(x) for x in self.values()]) + except (ValueError, TypeError): + return self.query() + + def __str__(self): + return safestr(self._str()) + + def __unicode__(self): + return safeunicode(self._str()) + + def __repr__(self): + return '' % repr(str(self)) + +class SQLLiteral: + """ + Protects a string from `sqlquote`. + + >>> sqlquote('NOW()') + + >>> sqlquote(SQLLiteral('NOW()')) + + """ + def __init__(self, v): + self.v = v + + def __repr__(self): + return self.v + +sqlliteral = SQLLiteral + +def _sqllist(values): + """ + >>> _sqllist([1, 2, 3]) + + """ + items = [] + items.append('(') + for i, v in enumerate(values): + if i != 0: + items.append(', ') + items.append(sqlparam(v)) + items.append(')') + return SQLQuery(items) + +def reparam(string_, dictionary): + """ + Takes a string and a dictionary and interpolates the string + using values from the dictionary. Returns an `SQLQuery` for the result. + + >>> reparam("s = $s", dict(s=True)) + + >>> reparam("s IN $s", dict(s=[1, 2])) + + """ + dictionary = dictionary.copy() # eval mucks with it + vals = [] + result = [] + for live, chunk in _interpolate(string_): + if live: + v = eval(chunk, dictionary) + result.append(sqlquote(v)) + else: + result.append(chunk) + return SQLQuery.join(result, '') + +def sqlify(obj): + """ + converts `obj` to its proper SQL version + + >>> sqlify(None) + 'NULL' + >>> sqlify(True) + "'t'" + >>> sqlify(3) + '3' + """ + # because `1 == True and hash(1) == hash(True)` + # we have to do this the hard way... + + if obj is None: + return 'NULL' + elif obj is True: + return "'t'" + elif obj is False: + return "'f'" + elif datetime and isinstance(obj, datetime.datetime): + return repr(obj.isoformat()) + else: + if isinstance(obj, unicode): obj = obj.encode('utf8') + return repr(obj) + +def sqllist(lst): + """ + Converts the arguments for use in something like a WHERE clause. + + >>> sqllist(['a', 'b']) + 'a, b' + >>> sqllist('a') + 'a' + >>> sqllist(u'abc') + u'abc' + """ + if isinstance(lst, basestring): + return lst + else: + return ', '.join(lst) + +def sqlors(left, lst): + """ + `left is a SQL clause like `tablename.arg = ` + and `lst` is a list of values. Returns a reparam-style + pair featuring the SQL that ORs together the clause + for each item in the lst. + + >>> sqlors('foo = ', []) + + >>> sqlors('foo = ', [1]) + + >>> sqlors('foo = ', 1) + + >>> sqlors('foo = ', [1,2,3]) + + """ + if isinstance(lst, iters): + lst = list(lst) + ln = len(lst) + if ln == 0: + return SQLQuery("1=2") + if ln == 1: + lst = lst[0] + + if isinstance(lst, iters): + return SQLQuery(['('] + + sum([[left, sqlparam(x), ' OR '] for x in lst], []) + + ['1=2)'] + ) + else: + return left + sqlparam(lst) + +def sqlwhere(dictionary, grouping=' AND '): + """ + Converts a `dictionary` to an SQL WHERE clause `SQLQuery`. + + >>> sqlwhere({'cust_id': 2, 'order_id':3}) + + >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ') + + >>> sqlwhere({'a': 'a', 'b': 'b'}).query() + 'a = %s AND b = %s' + """ + return SQLQuery.join([k + ' = ' + sqlparam(v) for k, v in dictionary.items()], grouping) + +def sqlquote(a): + """ + Ensures `a` is quoted properly for use in a SQL query. + + >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3) + + >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3]) + + """ + if isinstance(a, list): + return _sqllist(a) + else: + return sqlparam(a).sqlquery() + +class Transaction: + """Database transaction.""" + def __init__(self, ctx): + self.ctx = ctx + self.transaction_count = transaction_count = len(ctx.transactions) + + class transaction_engine: + """Transaction Engine used in top level transactions.""" + def do_transact(self): + ctx.commit(unload=False) + + def do_commit(self): + ctx.commit() + + def do_rollback(self): + ctx.rollback() + + class subtransaction_engine: + """Transaction Engine used in sub transactions.""" + def query(self, q): + db_cursor = ctx.db.cursor() + ctx.db_execute(db_cursor, SQLQuery(q % transaction_count)) + + def do_transact(self): + self.query('SAVEPOINT webpy_sp_%s') + + def do_commit(self): + self.query('RELEASE SAVEPOINT webpy_sp_%s') + + def do_rollback(self): + self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s') + + class dummy_engine: + """Transaction Engine used instead of subtransaction_engine + when sub transactions are not supported.""" + do_transact = do_commit = do_rollback = lambda self: None + + if self.transaction_count: + # nested transactions are not supported in some databases + if self.ctx.get('ignore_nested_transactions'): + self.engine = dummy_engine() + else: + self.engine = subtransaction_engine() + else: + self.engine = transaction_engine() + + self.engine.do_transact() + self.ctx.transactions.append(self) + + def __enter__(self): + return self + + def __exit__(self, exctype, excvalue, traceback): + if exctype is not None: + self.rollback() + else: + self.commit() + + def commit(self): + if len(self.ctx.transactions) > self.transaction_count: + self.engine.do_commit() + self.ctx.transactions = self.ctx.transactions[:self.transaction_count] + + def rollback(self): + if len(self.ctx.transactions) > self.transaction_count: + self.engine.do_rollback() + self.ctx.transactions = self.ctx.transactions[:self.transaction_count] + +class DB: + """Database""" + def __init__(self, db_module, keywords): + """Creates a database. + """ + # some DB implementaions take optional paramater `driver` to use a specific driver modue + # but it should not be passed to connect + keywords.pop('driver', None) + + self.db_module = db_module + self.keywords = keywords + + + self._ctx = threadeddict() + # flag to enable/disable printing queries + self.printing = config.get('debug', False) + self.supports_multiple_insert = False + + try: + import DBUtils + # enable pooling if DBUtils module is available. + self.has_pooling = True + except ImportError: + self.has_pooling = False + + # Pooling can be disabled by passing pooling=False in the keywords. + self.has_pooling = self.keywords.pop('pooling', True) and self.has_pooling + + def _getctx(self): + if not self._ctx.get('db'): + self._load_context(self._ctx) + return self._ctx + ctx = property(_getctx) + + def _load_context(self, ctx): + ctx.dbq_count = 0 + ctx.transactions = [] # stack of transactions + + if self.has_pooling: + ctx.db = self._connect_with_pooling(self.keywords) + else: + ctx.db = self._connect(self.keywords) + ctx.db_execute = self._db_execute + + if not hasattr(ctx.db, 'commit'): + ctx.db.commit = lambda: None + + if not hasattr(ctx.db, 'rollback'): + ctx.db.rollback = lambda: None + + def commit(unload=True): + # do db commit and release the connection if pooling is enabled. + ctx.db.commit() + if unload and self.has_pooling: + self._unload_context(self._ctx) + + def rollback(): + # do db rollback and release the connection if pooling is enabled. + ctx.db.rollback() + if self.has_pooling: + self._unload_context(self._ctx) + + ctx.commit = commit + ctx.rollback = rollback + + def _unload_context(self, ctx): + del ctx.db + + def _connect(self, keywords): + return self.db_module.connect(**keywords) + + def _connect_with_pooling(self, keywords): + def get_pooled_db(): + from DBUtils import PooledDB + + # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator` + # see Bug#122112 + + if PooledDB.__version__.split('.') < '0.9.3'.split('.'): + return PooledDB.PooledDB(dbapi=self.db_module, **keywords) + else: + return PooledDB.PooledDB(creator=self.db_module, **keywords) + + if getattr(self, '_pooleddb', None) is None: + self._pooleddb = get_pooled_db() + + return self._pooleddb.connection() + + def _db_cursor(self): + return self.ctx.db.cursor() + + def _param_marker(self): + """Returns parameter marker based on paramstyle attribute if this database.""" + style = getattr(self, 'paramstyle', 'pyformat') + + if style == 'qmark': + return '?' + elif style == 'numeric': + return ':1' + elif style in ['format', 'pyformat']: + return '%s' + raise UnknownParamstyle, style + + def _db_execute(self, cur, sql_query): + """executes an sql query""" + self.ctx.dbq_count += 1 + + try: + a = time.time() + paramstyle = getattr(self, 'paramstyle', 'pyformat') + out = cur.execute(sql_query.query(paramstyle), sql_query.values()) + b = time.time() + except: + if self.printing: + print >> debug, 'ERR:', str(sql_query) + if self.ctx.transactions: + self.ctx.transactions[-1].rollback() + else: + self.ctx.rollback() + raise + + if self.printing: + print >> debug, '%s (%s): %s' % (round(b-a, 2), self.ctx.dbq_count, str(sql_query)) + return out + + def _where(self, where, vars): + if isinstance(where, (int, long)): + where = "id = " + sqlparam(where) + #@@@ for backward-compatibility + elif isinstance(where, (list, tuple)) and len(where) == 2: + where = SQLQuery(where[0], where[1]) + elif isinstance(where, SQLQuery): + pass + else: + where = reparam(where, vars) + return where + + def query(self, sql_query, vars=None, processed=False, _test=False): + """ + Execute SQL query `sql_query` using dictionary `vars` to interpolate it. + If `processed=True`, `vars` is a `reparam`-style list to use + instead of interpolating. + + >>> db = DB(None, {}) + >>> db.query("SELECT * FROM foo", _test=True) + + >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True) + + >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True) + + """ + if vars is None: vars = {} + + if not processed and not isinstance(sql_query, SQLQuery): + sql_query = reparam(sql_query, vars) + + if _test: return sql_query + + db_cursor = self._db_cursor() + self._db_execute(db_cursor, sql_query) + + if db_cursor.description: + names = [x[0] for x in db_cursor.description] + def iterwrapper(): + row = db_cursor.fetchone() + while row: + yield storage(dict(zip(names, row))) + row = db_cursor.fetchone() + out = iterbetter(iterwrapper()) + out.__len__ = lambda: int(db_cursor.rowcount) + out.list = lambda: [storage(dict(zip(names, x))) \ + for x in db_cursor.fetchall()] + else: + out = db_cursor.rowcount + + if not self.ctx.transactions: + self.ctx.commit() + return out + + def select(self, tables, vars=None, what='*', where=None, order=None, group=None, + limit=None, offset=None, _test=False): + """ + Selects `what` from `tables` with clauses `where`, `order`, + `group`, `limit`, and `offset`. Uses vars to interpolate. + Otherwise, each clause can be a SQLQuery. + + >>> db = DB(None, {}) + >>> db.select('foo', _test=True) + + >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True) + + """ + if vars is None: vars = {} + sql_clauses = self.sql_clauses(what, tables, where, group, order, limit, offset) + clauses = [self.gen_clause(sql, val, vars) for sql, val in sql_clauses if val is not None] + qout = SQLQuery.join(clauses) + if _test: return qout + return self.query(qout, processed=True) + + def where(self, table, what='*', order=None, group=None, limit=None, + offset=None, _test=False, **kwargs): + """ + Selects from `table` where keys are equal to values in `kwargs`. + + >>> db = DB(None, {}) + >>> db.where('foo', bar_id=3, _test=True) + + >>> db.where('foo', source=2, crust='dewey', _test=True) + + """ + where = [] + for k, v in kwargs.iteritems(): + where.append(k + ' = ' + sqlquote(v)) + return self.select(table, what=what, order=order, + group=group, limit=limit, offset=offset, _test=_test, + where=SQLQuery.join(where, ' AND ')) + + def sql_clauses(self, what, tables, where, group, order, limit, offset): + return ( + ('SELECT', what), + ('FROM', sqllist(tables)), + ('WHERE', where), + ('GROUP BY', group), + ('ORDER BY', order), + ('LIMIT', limit), + ('OFFSET', offset)) + + def gen_clause(self, sql, val, vars): + if isinstance(val, (int, long)): + if sql == 'WHERE': + nout = 'id = ' + sqlquote(val) + else: + nout = SQLQuery(val) + #@@@ + elif isinstance(val, (list, tuple)) and len(val) == 2: + nout = SQLQuery(val[0], val[1]) # backwards-compatibility + elif isinstance(val, SQLQuery): + nout = val + else: + nout = reparam(val, vars) + + def xjoin(a, b): + if a and b: return a + ' ' + b + else: return a or b + + return xjoin(sql, nout) + + def insert(self, tablename, seqname=None, _test=False, **values): + """ + Inserts `values` into `tablename`. Returns current sequence ID. + Set `seqname` to the ID if it's not the default, or to `False` + if there isn't one. + + >>> db = DB(None, {}) + >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True) + >>> q + + >>> q.query() + 'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())' + >>> q.values() + [2, 'bob'] + """ + def q(x): return "(" + x + ")" + + if values: + _keys = SQLQuery.join(values.keys(), ', ') + _values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ') + sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values) + else: + sql_query = SQLQuery("INSERT INTO %s DEFAULT VALUES" % tablename) + + if _test: return sql_query + + db_cursor = self._db_cursor() + if seqname is not False: + sql_query = self._process_insert_query(sql_query, tablename, seqname) + + if isinstance(sql_query, tuple): + # for some databases, a separate query has to be made to find + # the id of the inserted row. + q1, q2 = sql_query + self._db_execute(db_cursor, q1) + self._db_execute(db_cursor, q2) + else: + self._db_execute(db_cursor, sql_query) + + try: + out = db_cursor.fetchone()[0] + except Exception: + out = None + + if not self.ctx.transactions: + self.ctx.commit() + return out + + def multiple_insert(self, tablename, values, seqname=None, _test=False): + """ + Inserts multiple rows into `tablename`. The `values` must be a list of dictioanries, + one for each row to be inserted, each with the same set of keys. + Returns the list of ids of the inserted rows. + Set `seqname` to the ID if it's not the default, or to `False` + if there isn't one. + + >>> db = DB(None, {}) + >>> db.supports_multiple_insert = True + >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}] + >>> db.multiple_insert('person', values=values, _test=True) + + """ + if not values: + return [] + + if not self.supports_multiple_insert: + out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values] + if seqname is False: + return None + else: + return out + + keys = values[0].keys() + #@@ make sure all keys are valid + + # make sure all rows have same keys. + for v in values: + if v.keys() != keys: + raise ValueError, 'Bad data' + + sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys))) + + data = [] + for row in values: + d = SQLQuery.join([SQLParam(row[k]) for k in keys], ', ') + data.append('(' + d + ')') + sql_query += SQLQuery.join(data, ', ') + + if _test: return sql_query + + db_cursor = self._db_cursor() + if seqname is not False: + sql_query = self._process_insert_query(sql_query, tablename, seqname) + + if isinstance(sql_query, tuple): + # for some databases, a separate query has to be made to find + # the id of the inserted row. + q1, q2 = sql_query + self._db_execute(db_cursor, q1) + self._db_execute(db_cursor, q2) + else: + self._db_execute(db_cursor, sql_query) + + try: + out = db_cursor.fetchone()[0] + out = range(out-len(values)+1, out+1) + except Exception: + out = None + + if not self.ctx.transactions: + self.ctx.commit() + return out + + + def update(self, tables, where, vars=None, _test=False, **values): + """ + Update `tables` with clause `where` (interpolated using `vars`) + and setting `values`. + + >>> db = DB(None, {}) + >>> name = 'Joseph' + >>> q = db.update('foo', where='name = $name', name='bob', age=2, + ... created=SQLLiteral('NOW()'), vars=locals(), _test=True) + >>> q + + >>> q.query() + 'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s' + >>> q.values() + [2, 'bob', 'Joseph'] + """ + if vars is None: vars = {} + where = self._where(where, vars) + + query = ( + "UPDATE " + sqllist(tables) + + " SET " + sqlwhere(values, ', ') + + " WHERE " + where) + + if _test: return query + + db_cursor = self._db_cursor() + self._db_execute(db_cursor, query) + if not self.ctx.transactions: + self.ctx.commit() + return db_cursor.rowcount + + def delete(self, table, where, using=None, vars=None, _test=False): + """ + Deletes from `table` with clauses `where` and `using`. + + >>> db = DB(None, {}) + >>> name = 'Joe' + >>> db.delete('foo', where='name = $name', vars=locals(), _test=True) + + """ + if vars is None: vars = {} + where = self._where(where, vars) + + q = 'DELETE FROM ' + table + if where: q += ' WHERE ' + where + if using: q += ' USING ' + sqllist(using) + + if _test: return q + + db_cursor = self._db_cursor() + self._db_execute(db_cursor, q) + if not self.ctx.transactions: + self.ctx.commit() + return db_cursor.rowcount + + def _process_insert_query(self, query, tablename, seqname): + return query + + def transaction(self): + """Start a transaction.""" + return Transaction(self.ctx) + +class PostgresDB(DB): + """Postgres driver.""" + def __init__(self, **keywords): + if 'pw' in keywords: + keywords['password'] = keywords.pop('pw') + + db_module = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=keywords.pop('driver', None)) + if db_module.__name__ == "psycopg2": + import psycopg2.extensions + psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) + + # if db is not provided postgres driver will take it from PGDATABASE environment variable + if 'db' in keywords: + keywords['database'] = keywords.pop('db') + + self.dbname = "postgres" + self.paramstyle = db_module.paramstyle + DB.__init__(self, db_module, keywords) + self.supports_multiple_insert = True + self._sequences = None + + def _process_insert_query(self, query, tablename, seqname): + if seqname is None: + # when seqname is not provided guess the seqname and make sure it exists + seqname = tablename + "_id_seq" + if seqname not in self._get_all_sequences(): + seqname = None + + if seqname: + query += "; SELECT currval('%s')" % seqname + + return query + + def _get_all_sequences(self): + """Query postgres to find names of all sequences used in this database.""" + if self._sequences is None: + q = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'" + self._sequences = set([c.relname for c in self.query(q)]) + return self._sequences + + def _connect(self, keywords): + conn = DB._connect(self, keywords) + conn.set_client_encoding('UTF8') + return conn + + def _connect_with_pooling(self, keywords): + conn = DB._connect_with_pooling(self, keywords) + conn._con._con.set_client_encoding('UTF8') + return conn + +class MySQLDB(DB): + def __init__(self, **keywords): + import MySQLdb as db + if 'pw' in keywords: + keywords['passwd'] = keywords['pw'] + del keywords['pw'] + + if 'charset' not in keywords: + keywords['charset'] = 'utf8' + elif keywords['charset'] is None: + del keywords['charset'] + + self.paramstyle = db.paramstyle = 'pyformat' # it's both, like psycopg + self.dbname = "mysql" + DB.__init__(self, db, keywords) + self.supports_multiple_insert = True + + def _process_insert_query(self, query, tablename, seqname): + return query, SQLQuery('SELECT last_insert_id();') + +def import_driver(drivers, preferred=None): + """Import the first available driver or preferred driver. + """ + if preferred: + drivers = [preferred] + + for d in drivers: + try: + return __import__(d, None, None, ['x']) + except ImportError: + pass + raise ImportError("Unable to import " + " or ".join(drivers)) + +class SqliteDB(DB): + def __init__(self, **keywords): + db = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=keywords.pop('driver', None)) + + if db.__name__ in ["sqlite3", "pysqlite2.dbapi2"]: + db.paramstyle = 'qmark' + + self.paramstyle = db.paramstyle + keywords['database'] = keywords.pop('db') + self.dbname = "sqlite" + DB.__init__(self, db, keywords) + + def _process_insert_query(self, query, tablename, seqname): + return query, SQLQuery('SELECT last_insert_rowid();') + + def query(self, *a, **kw): + out = DB.query(self, *a, **kw) + if isinstance(out, iterbetter): + del out.__len__ + return out + +class FirebirdDB(DB): + """Firebird Database. + """ + def __init__(self, **keywords): + try: + import kinterbasdb as db + except Exception: + db = None + pass + if 'pw' in keywords: + keywords['passwd'] = keywords['pw'] + del keywords['pw'] + keywords['database'] = keywords['db'] + del keywords['db'] + DB.__init__(self, db, keywords) + + def delete(self, table, where=None, using=None, vars=None, _test=False): + # firebird doesn't support using clause + using=None + return DB.delete(self, table, where, using, vars, _test) + + def sql_clauses(self, what, tables, where, group, order, limit, offset): + return ( + ('SELECT', ''), + ('FIRST', limit), + ('SKIP', offset), + ('', what), + ('FROM', sqllist(tables)), + ('WHERE', where), + ('GROUP BY', group), + ('ORDER BY', order) + ) + +class MSSQLDB(DB): + def __init__(self, **keywords): + import pymssql as db + if 'pw' in keywords: + keywords['password'] = keywords.pop('pw') + keywords['database'] = keywords.pop('db') + self.dbname = "mssql" + DB.__init__(self, db, keywords) + + def sql_clauses(self, what, tables, where, group, order, limit, offset): + return ( + ('SELECT', what), + ('TOP', limit), + ('FROM', sqllist(tables)), + ('WHERE', where), + ('GROUP BY', group), + ('ORDER BY', order), + ('OFFSET', offset)) + + def _test(self): + """Test LIMIT. + + Fake presence of pymssql module for running tests. + >>> import sys + >>> sys.modules['pymssql'] = sys.modules['sys'] + + MSSQL has TOP clause instead of LIMIT clause. + >>> db = MSSQLDB(db='test', user='joe', pw='secret') + >>> db.select('foo', limit=4, _test=True) + + """ + pass + +class OracleDB(DB): + def __init__(self, **keywords): + import cx_Oracle as db + if 'pw' in keywords: + keywords['password'] = keywords.pop('pw') + + #@@ TODO: use db.makedsn if host, port is specified + keywords['dsn'] = keywords.pop('db') + self.dbname = 'oracle' + db.paramstyle = 'numeric' + self.paramstyle = db.paramstyle + + # oracle doesn't support pooling + keywords.pop('pooling', None) + DB.__init__(self, db, keywords) + + def _process_insert_query(self, query, tablename, seqname): + if seqname is None: + # It is not possible to get seq name from table name in Oracle + return query + else: + return query + "; SELECT %s.currval FROM dual" % seqname + +_databases = {} +def database(dburl=None, **params): + """Creates appropriate database using params. + + Pooling will be enabled if DBUtils module is available. + Pooling can be disabled by passing pooling=False in params. + """ + dbn = params.pop('dbn') + if dbn in _databases: + return _databases[dbn](**params) + else: + raise UnknownDB, dbn + +def register_database(name, clazz): + """ + Register a database. + + >>> class LegacyDB(DB): + ... def __init__(self, **params): + ... pass + ... + >>> register_database('legacy', LegacyDB) + >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret') + """ + _databases[name] = clazz + +register_database('mysql', MySQLDB) +register_database('postgres', PostgresDB) +register_database('sqlite', SqliteDB) +register_database('firebird', FirebirdDB) +register_database('mssql', MSSQLDB) +register_database('oracle', OracleDB) + +def _interpolate(format): + """ + Takes a format string and returns a list of 2-tuples of the form + (boolean, string) where boolean says whether string should be evaled + or not. + + from (public domain, Ka-Ping Yee) + """ + from tokenize import tokenprog + + def matchorfail(text, pos): + match = tokenprog.match(text, pos) + if match is None: + raise _ItplError(text, pos) + return match, match.end() + + namechars = "abcdefghijklmnopqrstuvwxyz" \ + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; + chunks = [] + pos = 0 + + while 1: + dollar = format.find("$", pos) + if dollar < 0: + break + nextchar = format[dollar + 1] + + if nextchar == "{": + chunks.append((0, format[pos:dollar])) + pos, level = dollar + 2, 1 + while level: + match, pos = matchorfail(format, pos) + tstart, tend = match.regs[3] + token = format[tstart:tend] + if token == "{": + level = level + 1 + elif token == "}": + level = level - 1 + chunks.append((1, format[dollar + 2:pos - 1])) + + elif nextchar in namechars: + chunks.append((0, format[pos:dollar])) + match, pos = matchorfail(format, dollar + 1) + while pos < len(format): + if format[pos] == "." and \ + pos + 1 < len(format) and format[pos + 1] in namechars: + match, pos = matchorfail(format, pos + 1) + elif format[pos] in "([": + pos, level = pos + 1, 1 + while level: + match, pos = matchorfail(format, pos) + tstart, tend = match.regs[3] + token = format[tstart:tend] + if token[0] in "([": + level = level + 1 + elif token[0] in ")]": + level = level - 1 + else: + break + chunks.append((1, format[dollar + 1:pos])) + else: + chunks.append((0, format[pos:dollar + 1])) + pos = dollar + 1 + (nextchar == "$") + + if pos < len(format): + chunks.append((0, format[pos:])) + return chunks + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/contrib/web/debugerror.py b/contrib/web/debugerror.py new file mode 100644 index 0000000..4f0fc6c --- /dev/null +++ b/contrib/web/debugerror.py @@ -0,0 +1,354 @@ +""" +pretty debug errors +(part of web.py) + +portions adapted from Django +Copyright (c) 2005, the Lawrence Journal-World +Used under the modified BSD license: +http://www.xfree86.org/3.3.6/COPYRIGHT2.html#5 +""" + +__all__ = ["debugerror", "djangoerror", "emailerrors"] + +import sys, urlparse, pprint, traceback +from net import websafe +from template import Template +from utils import sendmail, safestr +import webapi as web + +import os, os.path +whereami = os.path.join(os.getcwd(), __file__) +whereami = os.path.sep.join(whereami.split(os.path.sep)[:-1]) +djangoerror_t = """\ +$def with (exception_type, exception_value, frames) + + + + + + $exception_type at $ctx.path + + + + + +$def dicttable (d, kls='req', id=None): + $ items = d and d.items() or [] + $items.sort() + $:dicttable_items(items, kls, id) + +$def dicttable_items(items, kls='req', id=None): + $if items: + + + $for k, v in items: + + +
VariableValue
$k
$prettify(v)
+ $else: +

No data.

+ +
+

$exception_type at $ctx.path

+

$exception_value

+ + + + + + +
Python$frames[0].filename in $frames[0].function, line $frames[0].lineno
Web$ctx.method $ctx.home$ctx.path
+
+
+

Traceback (innermost first)

+
    +$for frame in frames: +
  • + $frame.filename in $frame.function + $if frame.context_line is not None: +
    + $if frame.pre_context: +
      + $for line in frame.pre_context: +
    1. $line
    2. +
    +
    1. $frame.context_line ...
    + $if frame.post_context: +
      + $for line in frame.post_context: +
    1. $line
    2. +
    +
    + + $if frame.vars: +
    + Local vars + $# $inspect.formatargvalues(*inspect.getargvalues(frame['tb'].tb_frame)) +
    + $:dicttable(frame.vars, kls='vars', id=('v' + str(frame.id))) +
  • +
+
+ +
+$if ctx.output or ctx.headers: +

Response so far

+

HEADERS

+ $:dicttable_items(ctx.headers) + +

BODY

+

+ $ctx.output +

+ +

Request information

+ +

INPUT

+$:dicttable(web.input(_unicode=False)) + + +$:dicttable(web.cookies()) + +

META

+$ newctx = [(k, v) for (k, v) in ctx.iteritems() if not k.startswith('_') and not isinstance(v, dict)] +$:dicttable(dict(newctx)) + +

ENVIRONMENT

+$:dicttable(ctx.env) +
+ +
+

+ You're seeing this error because you have web.config.debug + set to True. Set that to False if you don't to see this. +

+
+ + + +""" + +djangoerror_r = None + +def djangoerror(): + def _get_lines_from_file(filename, lineno, context_lines): + """ + Returns context_lines before and after lineno from file. + Returns (pre_context_lineno, pre_context, context_line, post_context). + """ + try: + source = open(filename).readlines() + lower_bound = max(0, lineno - context_lines) + upper_bound = lineno + context_lines + + pre_context = \ + [line.strip('\n') for line in source[lower_bound:lineno]] + context_line = source[lineno].strip('\n') + post_context = \ + [line.strip('\n') for line in source[lineno + 1:upper_bound]] + + return lower_bound, pre_context, context_line, post_context + except (OSError, IOError, IndexError): + return None, [], None, [] + + exception_type, exception_value, tback = sys.exc_info() + frames = [] + while tback is not None: + filename = tback.tb_frame.f_code.co_filename + function = tback.tb_frame.f_code.co_name + lineno = tback.tb_lineno - 1 + + # hack to get correct line number for templates + lineno += tback.tb_frame.f_locals.get("__lineoffset__", 0) + + pre_context_lineno, pre_context, context_line, post_context = \ + _get_lines_from_file(filename, lineno, 7) + + if '__hidetraceback__' not in tback.tb_frame.f_locals: + frames.append(web.storage({ + 'tback': tback, + 'filename': filename, + 'function': function, + 'lineno': lineno, + 'vars': tback.tb_frame.f_locals, + 'id': id(tback), + 'pre_context': pre_context, + 'context_line': context_line, + 'post_context': post_context, + 'pre_context_lineno': pre_context_lineno, + })) + tback = tback.tb_next + frames.reverse() + urljoin = urlparse.urljoin + def prettify(x): + try: + out = pprint.pformat(x) + except Exception, e: + out = '[could not display: <' + e.__class__.__name__ + \ + ': '+str(e)+'>]' + return out + + global djangoerror_r + if djangoerror_r is None: + djangoerror_r = Template(djangoerror_t, filename=__file__, filter=websafe) + + t = djangoerror_r + globals = {'ctx': web.ctx, 'web':web, 'dict':dict, 'str':str, 'prettify': prettify} + t.t.func_globals.update(globals) + return t(exception_type, exception_value, frames) + +def debugerror(): + """ + A replacement for `internalerror` that presents a nice page with lots + of debug information for the programmer. + + (Based on the beautiful 500 page from [Django](http://djangoproject.com/), + designed by [Wilson Miner](http://wilsonminer.com/).) + """ + return web._InternalError(djangoerror()) + +def emailerrors(to_address, olderror, from_address=None): + """ + Wraps the old `internalerror` handler (pass as `olderror`) to + additionally email all errors to `to_address`, to aid in + debugging production websites. + + Emails contain a normal text traceback as well as an + attachment containing the nice `debugerror` page. + """ + from_address = from_address or to_address + + def emailerrors_internal(): + error = olderror() + tb = sys.exc_info() + error_name = tb[0] + error_value = tb[1] + tb_txt = ''.join(traceback.format_exception(*tb)) + path = web.ctx.path + request = web.ctx.method + ' ' + web.ctx.home + web.ctx.fullpath + + message = "\n%s\n\n%s\n\n" % (request, tb_txt) + + sendmail( + "your buggy site <%s>" % from_address, + "the bugfixer <%s>" % to_address, + "bug: %(error_name)s: %(error_value)s (%(path)s)" % locals(), + message, + attachments=[ + dict(filename="bug.html", content=safestr(djangoerror())) + ], + ) + return error + + return emailerrors_internal + +if __name__ == "__main__": + urls = ( + '/', 'index' + ) + from application import application + app = application(urls, globals()) + app.internalerror = debugerror + + class index: + def GET(self): + thisdoesnotexist + + app.run() diff --git a/contrib/web/form.py b/contrib/web/form.py new file mode 100644 index 0000000..4324e6c --- /dev/null +++ b/contrib/web/form.py @@ -0,0 +1,369 @@ +""" +HTML forms +(part of web.py) +""" + +import copy, re +import webapi as web +import utils, net + +def attrget(obj, attr, value=None): + if hasattr(obj, 'has_key') and obj.has_key(attr): return obj[attr] + if hasattr(obj, attr): return getattr(obj, attr) + return value + +class Form: + r""" + HTML form. + + >>> f = Form(Textbox("x")) + >>> f.render() + '\n \n
' + """ + def __init__(self, *inputs, **kw): + self.inputs = inputs + self.valid = True + self.note = None + self.validators = kw.pop('validators', []) + + def __call__(self, x=None): + o = copy.deepcopy(self) + if x: o.validates(x) + return o + + def render(self): + out = '' + out += self.rendernote(self.note) + out += '\n' + + for i in self.inputs: + html = i.pre + i.render() + self.rendernote(i.note) + i.post + if i.is_hidden(): + out += ' \n' % (html) + else: + out += ' \n' % (i.id, net.websafe(i.description), html) + out += "
%s
%s
" + return out + + def render_css(self): + out = [] + out.append(self.rendernote(self.note)) + for i in self.inputs: + if not i.is_hidden(): + out.append('' % (i.id, net.websafe(i.description))) + out.append(i.pre) + out.append(i.render()) + out.append(self.rendernote(i.note)) + out.append(i.post) + out.append('\n') + return ''.join(out) + + def rendernote(self, note): + if note: return '%s' % net.websafe(note) + else: return "" + + def validates(self, source=None, _validate=True, **kw): + source = source or kw or web.input() + out = True + for i in self.inputs: + v = attrget(source, i.name) + if _validate: + out = i.validate(v) and out + else: + i.value = v + if _validate: + out = out and self._validate(source) + self.valid = out + return out + + def _validate(self, value): + self.value = value + for v in self.validators: + if not v.valid(value): + self.note = v.msg + return False + return True + + def fill(self, source=None, **kw): + return self.validates(source, _validate=False, **kw) + + def __getitem__(self, i): + for x in self.inputs: + if x.name == i: return x + raise KeyError, i + + def __getattr__(self, name): + # don't interfere with deepcopy + inputs = self.__dict__.get('inputs') or [] + for x in inputs: + if x.name == name: return x + raise AttributeError, name + + def get(self, i, default=None): + try: + return self[i] + except KeyError: + return default + + def _get_d(self): #@@ should really be form.attr, no? + return utils.storage([(i.name, i.get_value()) for i in self.inputs]) + d = property(_get_d) + +class Input(object): + def __init__(self, name, *validators, **attrs): + self.name = name + self.validators = validators + self.attrs = attrs = AttributeList(attrs) + + self.description = attrs.pop('description', name) + self.value = attrs.pop('value', None) + self.pre = attrs.pop('pre', "") + self.post = attrs.pop('post', "") + self.note = None + + self.id = attrs.setdefault('id', self.get_default_id()) + + if 'class_' in attrs: + attrs['class'] = attrs['class_'] + del attrs['class_'] + + def is_hidden(self): + return False + + def get_type(self): + raise NotImplementedError + + def get_default_id(self): + return self.name + + def validate(self, value): + self.set_value(value) + + for v in self.validators: + if not v.valid(value): + self.note = v.msg + return False + return True + + def set_value(self, value): + self.value = value + + def get_value(self): + return self.value + + def render(self): + attrs = self.attrs.copy() + attrs['type'] = self.get_type() + if self.value is not None: + attrs['value'] = self.value + attrs['name'] = self.name + return '' % attrs + + def rendernote(self, note): + if note: return '%s' % net.websafe(note) + else: return "" + + def addatts(self): + # add leading space for backward-compatibility + return " " + str(self.attrs) + +class AttributeList(dict): + """List of atributes of input. + + >>> a = AttributeList(type='text', name='x', value=20) + >>> a + + """ + def copy(self): + return AttributeList(self) + + def __str__(self): + return " ".join(['%s="%s"' % (k, net.websafe(v)) for k, v in self.items()]) + + def __repr__(self): + return '' % repr(str(self)) + +class Textbox(Input): + """Textbox input. + + >>> Textbox(name='foo', value='bar').render() + '' + >>> Textbox(name='foo', value=0).render() + '' + """ + def get_type(self): + return 'text' + +class Password(Input): + """Password input. + + >>> Password(name='password', value='secret').render() + '' + """ + + def get_type(self): + return 'password' + +class Textarea(Input): + """Textarea input. + + >>> Textarea(name='foo', value='bar').render() + '' + """ + def render(self): + attrs = self.attrs.copy() + attrs['name'] = self.name + value = net.websafe(self.value or '') + return '' % (attrs, value) + +class Dropdown(Input): + r"""Dropdown/select input. + + >>> Dropdown(name='foo', args=['a', 'b', 'c'], value='b').render() + '\n' + >>> Dropdown(name='foo', args=[('a', 'aa'), ('b', 'bb'), ('c', 'cc')], value='b').render() + '\n' + """ + def __init__(self, name, args, *validators, **attrs): + self.args = args + super(Dropdown, self).__init__(name, *validators, **attrs) + + def render(self): + attrs = self.attrs.copy() + attrs['name'] = self.name + + x = '\n' + return x + +class Radio(Input): + def __init__(self, name, args, *validators, **attrs): + self.args = args + super(Radio, self).__init__(name, *validators, **attrs) + + def render(self): + x = '' + for arg in self.args: + if isinstance(arg, (tuple, list)): + value, desc= arg + else: + value, desc = arg, arg + attrs = self.attrs.copy() + attrs['name'] = self.name + attrs['type'] = 'radio' + attrs['value'] = arg + if self.value == arg: + attrs['checked'] = 'checked' + x += ' %s' % (attrs, net.websafe(desc)) + x += '' + return x + +class Checkbox(Input): + """Checkbox input. + + >>> Checkbox('foo', value='bar', checked=True).render() + '' + >>> Checkbox('foo', value='bar').render() + '' + >>> c = Checkbox('foo', value='bar') + >>> c.validate('on') + True + >>> c.render() + '' + """ + def __init__(self, name, *validators, **attrs): + self.checked = attrs.pop('checked', False) + Input.__init__(self, name, *validators, **attrs) + + def get_default_id(self): + value = utils.safestr(self.value or "") + return self.name + '_' + value.replace(' ', '_') + + def render(self): + attrs = self.attrs.copy() + attrs['type'] = 'checkbox' + attrs['name'] = self.name + attrs['value'] = self.value + + if self.checked: + attrs['checked'] = 'checked' + return '' % attrs + + def set_value(self, value): + self.checked = bool(value) + + def get_value(self): + return self.checked + +class Button(Input): + """HTML Button. + + >>> Button("save").render() + '' + >>> Button("action", value="save", html="Save Changes").render() + '' + """ + def __init__(self, name, *validators, **attrs): + super(Button, self).__init__(name, *validators, **attrs) + self.description = "" + + def render(self): + attrs = self.attrs.copy() + attrs['name'] = self.name + if self.value is not None: + attrs['value'] = self.value + html = attrs.pop('html', None) or net.websafe(self.name) + return '' % (attrs, html) + +class Hidden(Input): + """Hidden Input. + + >>> Hidden(name='foo', value='bar').render() + '' + """ + def is_hidden(self): + return True + + def get_type(self): + return 'hidden' + +class File(Input): + """File input. + + >>> File(name='f').render() + '' + """ + def get_type(self): + return 'file' + +class Validator: + def __deepcopy__(self, memo): return copy.copy(self) + def __init__(self, msg, test, jstest=None): utils.autoassign(self, locals()) + def valid(self, value): + try: return self.test(value) + except: return False + +notnull = Validator("Required", bool) + +class regexp(Validator): + def __init__(self, rexp, msg): + self.rexp = re.compile(rexp) + self.msg = msg + + def valid(self, value): + return bool(self.rexp.match(value)) + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/contrib/web/http.py b/contrib/web/http.py new file mode 100644 index 0000000..cc46ceb --- /dev/null +++ b/contrib/web/http.py @@ -0,0 +1,150 @@ +""" +HTTP Utilities +(from web.py) +""" + +__all__ = [ + "expires", "lastmodified", + "prefixurl", "modified", + "changequery", "url", + "profiler", +] + +import sys, os, threading, urllib, urlparse +try: import datetime +except ImportError: pass +import net, utils, webapi as web + +def prefixurl(base=''): + """ + Sorry, this function is really difficult to explain. + Maybe some other time. + """ + url = web.ctx.path.lstrip('/') + for i in xrange(url.count('/')): + base += '../' + if not base: + base = './' + return base + +def expires(delta): + """ + Outputs an `Expires` header for `delta` from now. + `delta` is a `timedelta` object or a number of seconds. + """ + if isinstance(delta, (int, long)): + delta = datetime.timedelta(seconds=delta) + date_obj = datetime.datetime.utcnow() + delta + web.header('Expires', net.httpdate(date_obj)) + +def lastmodified(date_obj): + """Outputs a `Last-Modified` header for `datetime`.""" + web.header('Last-Modified', net.httpdate(date_obj)) + +def modified(date=None, etag=None): + """ + Checks to see if the page has been modified since the version in the + requester's cache. + + When you publish pages, you can include `Last-Modified` and `ETag` + with the date the page was last modified and an opaque token for + the particular version, respectively. When readers reload the page, + the browser sends along the modification date and etag value for + the version it has in its cache. If the page hasn't changed, + the server can just return `304 Not Modified` and not have to + send the whole page again. + + This function takes the last-modified date `date` and the ETag `etag` + and checks the headers to see if they match. If they do, it returns + `True`, or otherwise it raises NotModified error. It also sets + `Last-Modified` and `ETag` output headers. + """ + try: + from __builtin__ import set + except ImportError: + # for python 2.3 + from sets import Set as set + + n = set([x.strip('" ') for x in web.ctx.env.get('HTTP_IF_NONE_MATCH', '').split(',')]) + m = net.parsehttpdate(web.ctx.env.get('HTTP_IF_MODIFIED_SINCE', '').split(';')[0]) + validate = False + if etag: + if '*' in n or etag in n: + validate = True + if date and m: + # we subtract a second because + # HTTP dates don't have sub-second precision + if date-datetime.timedelta(seconds=1) <= m: + validate = True + + if date: lastmodified(date) + if etag: web.header('ETag', '"' + etag + '"') + if validate: + raise web.notmodified() + else: + return True + +def urlencode(query, doseq=0): + """ + Same as urllib.urlencode, but supports unicode strings. + + >>> urlencode({'text':'foo bar'}) + 'text=foo+bar' + >>> urlencode({'x': [1, 2]}, doseq=True) + 'x=1&x=2' + """ + def convert(value, doseq=False): + if doseq and isinstance(value, list): + return [convert(v) for v in value] + else: + return utils.utf8(value) + + query = dict([(k, convert(v, doseq)) for k, v in query.items()]) + return urllib.urlencode(query, doseq=doseq) + +def changequery(query=None, **kw): + """ + Imagine you're at `/foo?a=1&b=2`. Then `changequery(a=3)` will return + `/foo?a=3&b=2` -- the same URL but with the arguments you requested + changed. + """ + if query is None: + query = web.rawinput(method='get') + for k, v in kw.iteritems(): + if v is None: + query.pop(k, None) + else: + query[k] = v + out = web.ctx.path + if query: + out += '?' + urlencode(query, doseq=True) + return out + +def url(path=None, **kw): + """ + Makes url by concatinating web.ctx.homepath and path and the + query string created using the arguments. + """ + if path is None: + path = web.ctx.path + if path.startswith("/"): + out = web.ctx.homepath + path + else: + out = path + + if kw: + out += '?' + urlencode(kw) + + return out + +def profiler(app): + """Outputs basic profiling information at the bottom of each response.""" + from utils import profile + def profile_internal(e, o): + out, result = profile(app)(e, o) + return list(out) + ['
' + net.websafe(result) + '
'] + return profile_internal + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/contrib/web/httpserver.py b/contrib/web/httpserver.py new file mode 100644 index 0000000..f2c91cc --- /dev/null +++ b/contrib/web/httpserver.py @@ -0,0 +1,250 @@ +__all__ = ["runsimple"] + +import sys, os +from SimpleHTTPServer import SimpleHTTPRequestHandler + +import webapi as web +import net +import utils + +def runbasic(func, server_address=("0.0.0.0", 8080)): + """ + Runs a simple HTTP server hosting WSGI app `func`. The directory `static/` + is hosted statically. + + Based on [WsgiServer][ws] from [Colin Stewart][cs]. + + [ws]: http://www.owlfish.com/software/wsgiutils/documentation/wsgi-server-api.html + [cs]: http://www.owlfish.com/ + """ + # Copyright (c) 2004 Colin Stewart (http://www.owlfish.com/) + # Modified somewhat for simplicity + # Used under the modified BSD license: + # http://www.xfree86.org/3.3.6/COPYRIGHT2.html#5 + + import SimpleHTTPServer, SocketServer, BaseHTTPServer, urlparse + import socket, errno + import traceback + + class WSGIHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): + def run_wsgi_app(self): + protocol, host, path, parameters, query, fragment = \ + urlparse.urlparse('http://dummyhost%s' % self.path) + + # we only use path, query + env = {'wsgi.version': (1, 0) + ,'wsgi.url_scheme': 'http' + ,'wsgi.input': self.rfile + ,'wsgi.errors': sys.stderr + ,'wsgi.multithread': 1 + ,'wsgi.multiprocess': 0 + ,'wsgi.run_once': 0 + ,'REQUEST_METHOD': self.command + ,'REQUEST_URI': self.path + ,'PATH_INFO': path + ,'QUERY_STRING': query + ,'CONTENT_TYPE': self.headers.get('Content-Type', '') + ,'CONTENT_LENGTH': self.headers.get('Content-Length', '') + ,'REMOTE_ADDR': self.client_address[0] + ,'SERVER_NAME': self.server.server_address[0] + ,'SERVER_PORT': str(self.server.server_address[1]) + ,'SERVER_PROTOCOL': self.request_version + } + + for http_header, http_value in self.headers.items(): + env ['HTTP_%s' % http_header.replace('-', '_').upper()] = \ + http_value + + # Setup the state + self.wsgi_sent_headers = 0 + self.wsgi_headers = [] + + try: + # We have there environment, now invoke the application + result = self.server.app(env, self.wsgi_start_response) + try: + try: + for data in result: + if data: + self.wsgi_write_data(data) + finally: + if hasattr(result, 'close'): + result.close() + except socket.error, socket_err: + # Catch common network errors and suppress them + if (socket_err.args[0] in \ + (errno.ECONNABORTED, errno.EPIPE)): + return + except socket.timeout, socket_timeout: + return + except: + print >> web.debug, traceback.format_exc(), + + if (not self.wsgi_sent_headers): + # We must write out something! + self.wsgi_write_data(" ") + return + + do_POST = run_wsgi_app + do_PUT = run_wsgi_app + do_DELETE = run_wsgi_app + + def do_GET(self): + if self.path.startswith('/static/'): + SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self) + else: + self.run_wsgi_app() + + def wsgi_start_response(self, response_status, response_headers, + exc_info=None): + if (self.wsgi_sent_headers): + raise Exception \ + ("Headers already sent and start_response called again!") + # Should really take a copy to avoid changes in the application.... + self.wsgi_headers = (response_status, response_headers) + return self.wsgi_write_data + + def wsgi_write_data(self, data): + if (not self.wsgi_sent_headers): + status, headers = self.wsgi_headers + # Need to send header prior to data + status_code = status[:status.find(' ')] + status_msg = status[status.find(' ') + 1:] + self.send_response(int(status_code), status_msg) + for header, value in headers: + self.send_header(header, value) + self.end_headers() + self.wsgi_sent_headers = 1 + # Send the data + self.wfile.write(data) + + class WSGIServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): + def __init__(self, func, server_address): + BaseHTTPServer.HTTPServer.__init__(self, + server_address, + WSGIHandler) + self.app = func + self.serverShuttingDown = 0 + + print "http://%s:%d/" % server_address + WSGIServer(func, server_address).serve_forever() + +def runsimple(func, server_address=("0.0.0.0", 8080)): + """ + Runs [CherryPy][cp] WSGI server hosting WSGI app `func`. + The directory `static/` is hosted statically. + + [cp]: http://www.cherrypy.org + """ + func = StaticMiddleware(func) + func = LogMiddleware(func) + + server = WSGIServer(server_address, func) + + print "http://%s:%d/" % server_address + try: + server.start() + except KeyboardInterrupt: + server.stop() + +def WSGIServer(server_address, wsgi_app): + """Creates CherryPy WSGI server listening at `server_address` to serve `wsgi_app`. + This function can be overwritten to customize the webserver or use a different webserver. + """ + from wsgiserver import CherryPyWSGIServer + return CherryPyWSGIServer(server_address, wsgi_app, server_name="localhost") + +class StaticApp(SimpleHTTPRequestHandler): + """WSGI application for serving static files.""" + def __init__(self, environ, start_response): + self.headers = [] + self.environ = environ + self.start_response = start_response + + def send_response(self, status, msg=""): + self.status = str(status) + " " + msg + + def send_header(self, name, value): + self.headers.append((name, value)) + + def end_headers(self): + pass + + def log_message(*a): pass + + def __iter__(self): + environ = self.environ + + self.path = environ.get('PATH_INFO', '') + self.client_address = environ.get('REMOTE_ADDR','-'), \ + environ.get('REMOTE_PORT','-') + self.command = environ.get('REQUEST_METHOD', '-') + + from cStringIO import StringIO + self.wfile = StringIO() # for capturing error + + f = self.send_head() + self.start_response(self.status, self.headers) + + if f: + block_size = 16 * 1024 + while True: + buf = f.read(block_size) + if not buf: + break + yield buf + f.close() + else: + value = self.wfile.getvalue() + yield value + +class StaticMiddleware: + """WSGI middleware for serving static files.""" + def __init__(self, app, prefix='/static/'): + self.app = app + self.prefix = prefix + + def __call__(self, environ, start_response): + path = environ.get('PATH_INFO', '') + if path.startswith(self.prefix): + return StaticApp(environ, start_response) + else: + return self.app(environ, start_response) + +class LogMiddleware: + """WSGI middleware for logging the status.""" + def __init__(self, app): + self.app = app + self.format = '%s - - [%s] "%s %s %s" - %s' + + from BaseHTTPServer import BaseHTTPRequestHandler + import StringIO + f = StringIO.StringIO() + + class FakeSocket: + def makefile(self, *a): + return f + + # take log_date_time_string method from BaseHTTPRequestHandler + self.log_date_time_string = BaseHTTPRequestHandler(FakeSocket(), None, None).log_date_time_string + + def __call__(self, environ, start_response): + def xstart_response(status, response_headers, *args): + out = start_response(status, response_headers, *args) + self.log(status, environ) + return out + + return self.app(environ, xstart_response) + + def log(self, status, environ): + outfile = environ.get('wsgi.errors', web.debug) + req = environ.get('PATH_INFO', '_') + protocol = environ.get('ACTUAL_SERVER_PROTOCOL', '-') + method = environ.get('REQUEST_METHOD', '-') + host = "%s:%s" % (environ.get('REMOTE_ADDR','-'), + environ.get('REMOTE_PORT','-')) + + time = self.log_date_time_string() + + msg = self.format % (host, time, protocol, method, req, status) + print >> outfile, utils.safestr(msg) diff --git a/contrib/web/net.py b/contrib/web/net.py new file mode 100644 index 0000000..6c3ee85 --- /dev/null +++ b/contrib/web/net.py @@ -0,0 +1,190 @@ +""" +Network Utilities +(from web.py) +""" + +__all__ = [ + "validipaddr", "validipport", "validip", "validaddr", + "urlquote", + "httpdate", "parsehttpdate", + "htmlquote", "htmlunquote", "websafe", +] + +import urllib, time +try: import datetime +except ImportError: pass + +def validipaddr(address): + """ + Returns True if `address` is a valid IPv4 address. + + >>> validipaddr('192.168.1.1') + True + >>> validipaddr('192.168.1.800') + False + >>> validipaddr('192.168.1') + False + """ + try: + octets = address.split('.') + if len(octets) != 4: + return False + for x in octets: + if not (0 <= int(x) <= 255): + return False + except ValueError: + return False + return True + +def validipport(port): + """ + Returns True if `port` is a valid IPv4 port. + + >>> validipport('9000') + True + >>> validipport('foo') + False + >>> validipport('1000000') + False + """ + try: + if not (0 <= int(port) <= 65535): + return False + except ValueError: + return False + return True + +def validip(ip, defaultaddr="0.0.0.0", defaultport=8080): + """Returns `(ip_address, port)` from string `ip_addr_port`""" + addr = defaultaddr + port = defaultport + + ip = ip.split(":", 1) + if len(ip) == 1: + if not ip[0]: + pass + elif validipaddr(ip[0]): + addr = ip[0] + elif validipport(ip[0]): + port = int(ip[0]) + else: + raise ValueError, ':'.join(ip) + ' is not a valid IP address/port' + elif len(ip) == 2: + addr, port = ip + if not validipaddr(addr) and validipport(port): + raise ValueError, ':'.join(ip) + ' is not a valid IP address/port' + port = int(port) + else: + raise ValueError, ':'.join(ip) + ' is not a valid IP address/port' + return (addr, port) + +def validaddr(string_): + """ + Returns either (ip_address, port) or "/path/to/socket" from string_ + + >>> validaddr('/path/to/socket') + '/path/to/socket' + >>> validaddr('8000') + ('0.0.0.0', 8000) + >>> validaddr('127.0.0.1') + ('127.0.0.1', 8080) + >>> validaddr('127.0.0.1:8000') + ('127.0.0.1', 8000) + >>> validaddr('fff') + Traceback (most recent call last): + ... + ValueError: fff is not a valid IP address/port + """ + if '/' in string_: + return string_ + else: + return validip(string_) + +def urlquote(val): + """ + Quotes a string for use in a URL. + + >>> urlquote('://?f=1&j=1') + '%3A//%3Ff%3D1%26j%3D1' + >>> urlquote(None) + '' + >>> urlquote(u'\u203d') + '%E2%80%BD' + """ + if val is None: return '' + if not isinstance(val, unicode): val = str(val) + else: val = val.encode('utf-8') + return urllib.quote(val) + +def httpdate(date_obj): + """ + Formats a datetime object for use in HTTP headers. + + >>> import datetime + >>> httpdate(datetime.datetime(1970, 1, 1, 1, 1, 1)) + 'Thu, 01 Jan 1970 01:01:01 GMT' + """ + return date_obj.strftime("%a, %d %b %Y %H:%M:%S GMT") + +def parsehttpdate(string_): + """ + Parses an HTTP date into a datetime object. + + >>> parsehttpdate('Thu, 01 Jan 1970 01:01:01 GMT') + datetime.datetime(1970, 1, 1, 1, 1, 1) + """ + try: + t = time.strptime(string_, "%a, %d %b %Y %H:%M:%S %Z") + except ValueError: + return None + return datetime.datetime(*t[:6]) + +def htmlquote(text): + """ + Encodes `text` for raw use in HTML. + + >>> htmlquote("<'&\\">") + '<'&">' + """ + text = text.replace("&", "&") # Must be done first! + text = text.replace("<", "<") + text = text.replace(">", ">") + text = text.replace("'", "'") + text = text.replace('"', """) + return text + +def htmlunquote(text): + """ + Decodes `text` that's HTML quoted. + + >>> htmlunquote('<'&">') + '<\\'&">' + """ + text = text.replace(""", '"') + text = text.replace("'", "'") + text = text.replace(">", ">") + text = text.replace("<", "<") + text = text.replace("&", "&") # Must be done last! + return text + +def websafe(val): + """ + Converts `val` so that it's safe for use in UTF-8 HTML. + + >>> websafe("<'&\\">") + '<'&">' + >>> websafe(None) + '' + >>> websafe(u'\u203d') + '\\xe2\\x80\\xbd' + """ + if val is None: + return '' + if isinstance(val, unicode): + val = val.encode('utf-8') + val = str(val) + return htmlquote(val) + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/contrib/web/session.py b/contrib/web/session.py new file mode 100644 index 0000000..a260318 --- /dev/null +++ b/contrib/web/session.py @@ -0,0 +1,319 @@ +""" +Session Management +(from web.py) +""" + +import os, time, datetime, random, base64 +try: + import cPickle as pickle +except ImportError: + import pickle +try: + import hashlib + sha1 = hashlib.sha1 +except ImportError: + import sha + sha1 = sha.new + +import utils +import webapi as web + +__all__ = [ + 'Session', 'SessionExpired', + 'Store', 'DiskStore', 'DBStore', +] + +web.config.session_parameters = utils.storage({ + 'cookie_name': 'webpy_session_id', + 'cookie_domain': None, + 'timeout': 86400, #24 * 60 * 60, # 24 hours in seconds + 'ignore_expiry': True, + 'ignore_change_ip': True, + 'secret_key': 'fLjUfxqXtfNoIldA0A0J', + 'expired_message': 'Session expired', +}) + +class SessionExpired(web.HTTPError): + def __init__(self, message): + web.HTTPError.__init__(self, '200 OK', {}, data=message) + +class Session(utils.ThreadedDict): + """Session management for web.py + """ + + def __init__(self, app, store, initializer=None): + self.__dict__['store'] = store + self.__dict__['_initializer'] = initializer + self.__dict__['_last_cleanup_time'] = 0 + self.__dict__['_config'] = utils.storage(web.config.session_parameters) + + if app: + app.add_processor(self._processor) + + def _processor(self, handler): + """Application processor to setup session for every request""" + self._cleanup() + self._load() + + try: + return handler() + finally: + self._save() + + def _load(self): + """Load the session from the store, by the id from cookie""" + cookie_name = self._config.cookie_name + cookie_domain = self._config.cookie_domain + self.session_id = web.cookies().get(cookie_name) + + # protection against session_id tampering + if self.session_id and not self._valid_session_id(self.session_id): + self.session_id = None + + self._check_expiry() + if self.session_id: + d = self.store[self.session_id] + self.update(d) + self._validate_ip() + + if not self.session_id: + self.session_id = self._generate_session_id() + + if self._initializer: + if isinstance(self._initializer, dict): + self.update(self._initializer) + elif hasattr(self._initializer, '__call__'): + self._initializer() + + self.ip = web.ctx.ip + + def _check_expiry(self): + # check for expiry + if self.session_id and self.session_id not in self.store: + if self._config.ignore_expiry: + self.session_id = None + else: + return self.expired() + + def _validate_ip(self): + # check for change of IP + if self.session_id and self.get('ip', None) != web.ctx.ip: + if not self._config.ignore_change_ip: + return self.expired() + + def _save(self): + cookie_name = self._config.cookie_name + cookie_domain = self._config.cookie_domain + if not self.get('_killed'): + web.setcookie(cookie_name, self.session_id, domain=cookie_domain) + self.store[self.session_id] = dict(self) + else: + web.setcookie(cookie_name, self.session_id, expires=-1, domain=cookie_domain) + + def _generate_session_id(self): + """Generate a random id for session""" + + while True: + rand = os.urandom(16) + now = time.time() + secret_key = self._config.secret_key + session_id = sha1("%s%s%s%s" %(rand, now, utils.safestr(web.ctx.ip), secret_key)) + session_id = session_id.hexdigest() + if session_id not in self.store: + break + return session_id + + def _valid_session_id(self, session_id): + rx = utils.re_compile('^[0-9a-fA-F]+$') + return rx.match(session_id) + + def _cleanup(self): + """Cleanup the stored sessions""" + current_time = time.time() + timeout = self._config.timeout + if current_time - self._last_cleanup_time > timeout: + self.store.cleanup(timeout) + self.__dict__['_last_cleanup_time'] = current_time + + def expired(self): + """Called when an expired session is atime""" + self._killed = True + self._save() + raise SessionExpired(self._config.expired_message) + + def kill(self): + """Kill the session, make it no longer available""" + del self.store[self.session_id] + self._killed = True + +class Store: + """Base class for session stores""" + + def __contains__(self, key): + raise NotImplementedError + + def __getitem__(self, key): + raise NotImplementedError + + def __setitem__(self, key, value): + raise NotImplementedError + + def cleanup(self, timeout): + """removes all the expired sessions""" + raise NotImplementedError + + def encode(self, session_dict): + """encodes session dict as a string""" + pickled = pickle.dumps(session_dict) + return base64.encodestring(pickled) + + def decode(self, session_data): + """decodes the data to get back the session dict """ + pickled = base64.decodestring(session_data) + return pickle.loads(pickled) + +class DiskStore(Store): + """ + Store for saving a session on disk. + + >>> import tempfile + >>> root = tempfile.mkdtemp() + >>> s = DiskStore(root) + >>> s['a'] = 'foo' + >>> s['a'] + 'foo' + >>> time.sleep(0.01) + >>> s.cleanup(0.01) + >>> s['a'] + Traceback (most recent call last): + ... + KeyError: 'a' + """ + def __init__(self, root): + # if the storage root doesn't exists, create it. + if not os.path.exists(root): + os.mkdir(root) + self.root = root + + def _get_path(self, key): + if os.path.sep in key: + raise ValueError, "Bad key: %s" % repr(key) + return os.path.join(self.root, key) + + def __contains__(self, key): + path = self._get_path(key) + return os.path.exists(path) + + def __getitem__(self, key): + path = self._get_path(key) + if os.path.exists(path): + pickled = open(path).read() + return self.decode(pickled) + else: + raise KeyError, key + + def __setitem__(self, key, value): + path = self._get_path(key) + pickled = self.encode(value) + try: + f = open(path, 'w') + try: + f.write(pickled) + finally: + f.close() + except IOError: + pass + + def __delitem__(self, key): + path = self._get_path(key) + if os.path.exists(path): + os.remove(path) + + def cleanup(self, timeout): + now = time.time() + for f in os.listdir(self.root): + path = self._get_path(f) + atime = os.stat(path).st_atime + if now - atime > timeout : + os.remove(path) + +class DBStore(Store): + """Store for saving a session in database + Needs a table with the following columns: + + session_id CHAR(128) UNIQUE NOT NULL, + atime DATETIME NOT NULL default current_timestamp, + data TEXT + """ + def __init__(self, db, table_name): + self.db = db + self.table = table_name + + def __contains__(self, key): + data = self.db.select(self.table, where="session_id=$key", vars=locals()) + return bool(list(data)) + + def __getitem__(self, key): + now = datetime.datetime.now() + try: + s = self.db.select(self.table, where="session_id=$key", vars=locals())[0] + self.db.update(self.table, where="session_id=$key", atime=now, vars=locals()) + except IndexError: + raise KeyError + else: + return self.decode(s.data) + + def __setitem__(self, key, value): + pickled = self.encode(value) + now = datetime.datetime.now() + if key in self: + self.db.update(self.table, where="session_id=$key", data=pickled, vars=locals()) + else: + self.db.insert(self.table, False, session_id=key, data=pickled ) + + def __delitem__(self, key): + self.db.delete(self.table, where="session_id=$key", vars=locals()) + + def cleanup(self, timeout): + timeout = datetime.timedelta(timeout/(24.0*60*60)) #timedelta takes numdays as arg + last_allowed_time = datetime.datetime.now() - timeout + self.db.delete(self.table, where="$last_allowed_time > atime", vars=locals()) + +class ShelfStore: + """Store for saving session using `shelve` module. + + import shelve + store = ShelfStore(shelve.open('session.shelf')) + + XXX: is shelve thread-safe? + """ + def __init__(self, shelf): + self.shelf = shelf + + def __contains__(self, key): + return key in self.shelf + + def __getitem__(self, key): + atime, v = self.shelf[key] + self[key] = v # update atime + return v + + def __setitem__(self, key, value): + self.shelf[key] = time.time(), value + + def __delitem__(self, key): + try: + del self.shelf[key] + except KeyError: + pass + + def cleanup(self, timeout): + now = time.time() + for k in self.shelf.keys(): + atime, v = self.shelf[k] + if now - atime > timeout : + del self[k] + +if __name__ == '__main__' : + import doctest + doctest.testmod() diff --git a/contrib/web/template.py b/contrib/web/template.py new file mode 100644 index 0000000..ef87200 --- /dev/null +++ b/contrib/web/template.py @@ -0,0 +1,1471 @@ +""" +simple, elegant templating +(part of web.py) + +Template design: + +Template string is split into tokens and the tokens are combined into nodes. +Parse tree is a nodelist. TextNode and ExpressionNode are simple nodes and +for-loop, if-loop etc are block nodes, which contain multiple child nodes. + +Each node can emit some python string. python string emitted by the +root node is validated for safeeval and executed using python in the given environment. + +Enough care is taken to make sure the generated code and the template has line to line match, +so that the error messages can point to exact line number in template. (It doesn't work in some cases still.) + +Grammar: + + template -> defwith sections + defwith -> '$def with (' arguments ')' | '' + sections -> section* + section -> block | assignment | line + + assignment -> '$ ' + line -> (text|expr)* + text -> + expr -> '$' pyexpr | '$(' pyexpr ')' | '${' pyexpr '}' + pyexpr -> +""" + +__all__ = [ + "Template", + "Render", "render", "frender", + "ParseError", "SecurityError", + "test" +] + +import tokenize +import os +import glob +import re + +from utils import storage, safeunicode, safestr, re_compile +from webapi import config +from net import websafe + +def splitline(text): + r""" + Splits the given text at newline. + + >>> splitline('foo\nbar') + ('foo\n', 'bar') + >>> splitline('foo') + ('foo', '') + >>> splitline('') + ('', '') + """ + index = text.find('\n') + 1 + if index: + return text[:index], text[index:] + else: + return text, '' + +class Parser: + """Parser Base. + """ + def __init__(self): + self.statement_nodes = STATEMENT_NODES + self.keywords = KEYWORDS + + def parse(self, text, name="