diff --git a/.pylintrc b/.pylintrc index d1c6c053..ba943f0a 100644 --- a/.pylintrc +++ b/.pylintrc @@ -20,3 +20,8 @@ disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-d [REPORTS] msg-template={path}:{line}: {msg} ({symbol}) reports=no + +[TYPECHECK] +# AST classes have dynamic members. Writer does not but for some reason pylint +# barfs on some of its members. +ignored-classes=pythonparser.ast.Module,grumpy.compiler.util.Writer diff --git a/.travis.yml b/.travis.yml index 7e6e53fa..7f0c7112 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,6 +2,10 @@ language: go os: - linux - osx +before_script: + - rvm get head || true # https://github.com/travis-ci/travis-ci/issues/6307 + - set -e # Run gofmt and lint serially to avoid confusing output. Run tests in parallel # for speed. script: make gofmt lint && make -j2 test +after_script: set +e diff --git a/AUTHORS.md b/AUTHORS.md index 60f1ab4b..b816b171 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -24,3 +24,4 @@ Contributors in the order of first contribution * [wuttem](https://github.com/wuttem) * [cclauss](https://github.com/cclauss) * [Mirko Dziadzka](https://github.com/MirkoDziadzka) +* [Dong-hee Na](https://github.com/corona10) diff --git a/Makefile b/Makefile index cb2ea761..9398f9e7 100644 --- a/Makefile +++ b/Makefile @@ -32,11 +32,20 @@ ifeq ($(PYTHON),) endif PYTHON_BIN := $(shell which $(PYTHON)) PYTHON_VER := $(word 2,$(shell $(PYTHON) -V 2>&1)) +GO_REQ_MAJ := 1 +GO_REQ_MIN := 9 +GO_MAJ_MIN := $(subst go,, $(word 3,$(shell go version 2>&1)) ) +GO_MAJ := $(word 1,$(subst ., ,$(GO_MAJ_MIN) )) +GO_MIN := $(word 2,$(subst ., ,$(GO_MAJ_MIN) )) ifeq ($(filter 2.7.%,$(PYTHON_VER)),) $(error unsupported Python version $(PYTHON_VER), Grumpy only supports 2.7.x. To use a different python binary such as python2, run: 'make PYTHON=python2 ...') endif +ifneq ($(shell test $(GO_MAJ) -ge $(GO_REQ_MAJ) -a $(GO_MIN) -ge $(GO_REQ_MIN) && echo ok),ok) + $(error unsupported Go version $(GO_VER), Grumpy requires at least $(GO_REQ_MAJ).$(GO_REQ_MIN). 
Please update Go) +endif + PY_DIR := build/lib/python2.7/site-packages PY_INSTALL_DIR := $(shell $(PYTHON) -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") @@ -44,6 +53,10 @@ export GOPATH := $(ROOT_DIR)/build export PYTHONPATH := $(ROOT_DIR)/$(PY_DIR) export PATH := $(ROOT_DIR)/build/bin:$(PATH) +GOPATH_PY_ROOT := $(GOPATH)/src/__python__ + +PYTHONPARSER_SRCS := $(patsubst third_party/%,$(PY_DIR)/grumpy/%,$(wildcard third_party/pythonparser/*.py)) + COMPILER_BIN := build/bin/grumpc COMPILER_SRCS := $(addprefix $(PY_DIR)/grumpy/compiler/,$(notdir $(shell find compiler -name '*.py' -not -name '*_test.py'))) $(PY_DIR)/grumpy/__init__.py COMPILER_TESTS := $(patsubst %.py,grumpy/%,$(filter-out compiler/expr_visitor_test.py compiler/stmt_test.py,$(wildcard compiler/*_test.py))) @@ -53,7 +66,9 @@ COMPILER_PASS_FILES := $(patsubst %,$(PY_DIR)/%.pass,$(COMPILER_TESTS)) COMPILER_EXPR_VISITOR_PASS_FILES := $(patsubst %,$(PY_DIR)/grumpy/compiler/expr_visitor_test.%of32.pass,$(shell seq 32)) COMPILER_STMT_PASS_FILES := $(patsubst %,$(PY_DIR)/grumpy/compiler/stmt_test.%of16.pass,$(shell seq 16)) COMPILER_D_FILES := $(patsubst %,$(PY_DIR)/%.d,$(COMPILER_TESTS)) -COMPILER := $(COMPILER_BIN) $(COMPILER_SRCS) +COMPILER := $(COMPILER_BIN) $(COMPILER_SRCS) $(PYTHONPARSER_SRCS) + +PKGC_BIN := build/bin/pkgc RUNNER_BIN := build/bin/grumprun RUNTIME_SRCS := $(addprefix build/src/grumpy/,$(notdir $(wildcard runtime/*.go))) @@ -62,13 +77,14 @@ RUNTIME_PASS_FILE := build/runtime.pass RUNTIME_COVER_FILE := $(PKG_DIR)/grumpy.cover RUNNER = $(RUNNER_BIN) $(COMPILER) $(RUNTIME) $(STDLIB) -GRUMPY_STDLIB_SRCS := $(shell find lib -name '*.py') -GRUMPY_STDLIB_PACKAGES := $(foreach x,$(GRUMPY_STDLIB_SRCS),$(patsubst lib/%.py,%,$(patsubst lib/%/__init__.py,%,$(x)))) -THIRD_PARTY_STDLIB_SRCS := $(shell find third_party -name '*.py') -THIRD_PARTY_STDLIB_PACKAGES := $(foreach x,$(THIRD_PARTY_STDLIB_SRCS),$(patsubst third_party/stdlib/%.py,%,$(patsubst third_party/pypy/%.py,%,$(patsubst third_party/pypy/%/__init__.py,%,$(patsubst third_party/stdlib/%/__init__.py,%,$(x)))))) -STDLIB_SRCS := $(GRUMPY_STDLIB_SRCS) $(THIRD_PARTY_STDLIB_SRCS) -STDLIB_PACKAGES := $(GRUMPY_STDLIB_PACKAGES) $(THIRD_PARTY_STDLIB_PACKAGES) -STDLIB := $(patsubst %,$(PKG_DIR)/grumpy/lib/%.a,$(STDLIB_PACKAGES)) +LIB_SRCS := $(patsubst lib/%,$(GOPATH_PY_ROOT)/%,$(shell find lib -name '*.py')) +THIRD_PARTY_STDLIB_SRCS := $(patsubst third_party/stdlib/%,$(GOPATH_PY_ROOT)/%,$(shell find third_party/stdlib -name '*.py')) +THIRD_PARTY_PYPY_SRCS := $(patsubst third_party/pypy/%,$(GOPATH_PY_ROOT)/%,$(shell find third_party/pypy -name '*.py')) +THIRD_PARTY_OUROBOROS_SRCS := $(patsubst third_party/ouroboros/%,$(GOPATH_PY_ROOT)/%,$(shell find third_party/ouroboros -name '*.py')) +STDLIB_SRCS := $(LIB_SRCS) $(THIRD_PARTY_STDLIB_SRCS) $(THIRD_PARTY_PYPY_SRCS) $(THIRD_PARTY_OUROBOROS_SRCS) + +STDLIB_PACKAGES := $(patsubst $(GOPATH_PY_ROOT)/%.py,%,$(patsubst $(GOPATH_PY_ROOT)/%/__init__.py,%,$(STDLIB_SRCS))) +STDLIB := $(patsubst %,$(PKG_DIR)/__python__/%.a,$(STDLIB_PACKAGES)) STDLIB_TESTS := \ itertools_test \ math_test \ @@ -78,12 +94,30 @@ STDLIB_TESTS := \ re_tests \ sys_test \ tempfile_test \ - test/test_tuple \ + test/test_bisect \ + test/test_colorsys \ + test/test_datetime \ test/test_dict \ + test/test_dircache \ + test/test_dummy_thread \ + test/test_fpformat \ + test/test_genericpath \ test/test_list \ + test/test_md5 \ + test/test_mimetools \ + test/test_mutex \ + test/test_operator \ + test/test_quopri \ + test/test_queue \ + 
test/test_rfc822 \ + test/test_sched \ + test/test_select \ test/test_slice \ + test/test_stat \ test/test_string \ - threading_test \ + test/test_threading \ + test/test_tuple \ + test/test_uu \ time_test \ types_test \ weetest_test @@ -96,12 +130,12 @@ ACCEPT_PY_PASS_FILES := $(patsubst %,build/%_py.pass,$(filter-out %/native_test, BENCHMARKS := $(patsubst %.py,%,$(wildcard benchmarks/*.py)) BENCHMARK_BINS := $(patsubst %,build/%_benchmark,$(BENCHMARKS)) -TOOL_BINS = $(patsubst %,build/bin/%,benchcmp coverparse diffrange) +TOOL_BINS = $(patsubst %,build/bin/%,benchcmp coverparse diffrange genmake pydeps) GOLINT_BIN = build/bin/golint PYLINT_BIN = build/bin/pylint -all: $(COMPILER) $(RUNTIME) $(STDLIB) $(TOOL_BINS) +all: $(COMPILER) $(RUNNER) $(RUNTIME) $(TOOL_BINS) benchmarks: $(BENCHMARK_BINS) @@ -132,28 +166,43 @@ $(COMPILER_SRCS) $(COMPILER_TEST_SRCS) $(COMPILER_SHARDED_TEST_SRCS): $(PY_DIR)/ @mkdir -p $(PY_DIR)/grumpy/compiler @cp -f $< $@ -$(COMPILER_PASS_FILES): %.pass: %.py $(COMPILER) +$(COMPILER_PASS_FILES): %.pass: %.py $(COMPILER) $(COMPILER_TEST_SRCS) @$(PYTHON) $< -q @touch $@ @echo compiler/`basename $*` PASS -$(COMPILER_D_FILES): $(PY_DIR)/%.d: $(PY_DIR)/%.py $(COMPILER_SRCS) - @$(PYTHON) -m modulefinder $< | awk '{if (match($$2, /^grumpy\>/)) { print "$(PY_DIR)/$*.pass: " substr($$3, length("$(ROOT_DIR)/") + 1) }}' > $@ +# NOTE: In the regex below we use (\.|$) instead of \> because the latter is +# not available in the awk available on OS X. +$(COMPILER_D_FILES): $(PY_DIR)/%.d: $(PY_DIR)/%.py $(COMPILER_SRCS) $(PYTHONPARSER_SRCS) + @$(PYTHON) -m modulefinder $< | awk '{if (match($$2, /^grumpy(\.|$$)/)) { print "$(PY_DIR)/$*.pass: " substr($$3, length("$(ROOT_DIR)/") + 1) }}' > $@ -include $(COMPILER_D_FILES) # Does not depend on stdlibs since it makes minimal use of them. -$(COMPILER_EXPR_VISITOR_PASS_FILES): $(PY_DIR)/grumpy/compiler/expr_visitor_test.%.pass: $(PY_DIR)/grumpy/compiler/expr_visitor_test.py $(RUNNER_BIN) $(COMPILER) $(RUNTIME) +$(COMPILER_EXPR_VISITOR_PASS_FILES): $(PY_DIR)/grumpy/compiler/expr_visitor_test.%.pass: $(PY_DIR)/grumpy/compiler/expr_visitor_test.py $(RUNNER_BIN) $(COMPILER) $(RUNTIME) $(PKG_DIR)/__python__/traceback.a @$(PYTHON) $< --shard=$* @touch $@ @echo 'compiler/expr_visitor_test $* PASS' +COMPILER_STMT_PASS_FILE_DEPS := \ + $(PKG_DIR)/__python__/__go__/grumpy.a \ + $(PKG_DIR)/__python__/__go__/os.a \ + $(PKG_DIR)/__python__/__go__/runtime.a \ + $(PKG_DIR)/__python__/__go__/time.a \ + $(PKG_DIR)/__python__/__go__/unicode.a \ + $(PKG_DIR)/__python__/sys.a \ + $(PKG_DIR)/__python__/traceback.a + # Does not depend on stdlibs since it makes minimal use of them. 
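# The handful of archives it does need are listed in COMPILER_STMT_PASS_FILE_DEPS
# above; as an illustration of the __python__ layout, the pure-Python module
# "traceback" is built from build/src/__python__/traceback/module.go into
# $(PKG_DIR)/__python__/traceback.a, while the __python__/__go__/ entries are
# native wrappers generated by pkgc (see the "Native modules" rules below).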
-$(COMPILER_STMT_PASS_FILES): $(PY_DIR)/grumpy/compiler/stmt_test.%.pass: $(PY_DIR)/grumpy/compiler/stmt_test.py $(RUNNER_BIN) $(COMPILER) $(RUNTIME) +$(COMPILER_STMT_PASS_FILES): $(PY_DIR)/grumpy/compiler/stmt_test.%.pass: $(PY_DIR)/grumpy/compiler/stmt_test.py $(RUNNER_BIN) $(COMPILER) $(RUNTIME) $(COMPILER_STMT_PASS_FILE_DEPS) @$(PYTHON) $< --shard=$* @touch $@ @echo 'compiler/stmt_test $* PASS' +$(PKGC_BIN): tools/pkgc.go + @mkdir -p $(@D) + @go build -o $@ $< + # ------------------------------------------------------------------------------ # Grumpy runtime # ------------------------------------------------------------------------------ @@ -197,52 +246,61 @@ $(PYLINT_BIN): @cd build/third_party && curl -sL https://pypi.io/packages/source/p/pylint/pylint-1.6.4.tar.gz | tar -zx @cd build/third_party/pylint-1.6.4 && $(PYTHON) setup.py install --prefix $(ROOT_DIR)/build -pylint: $(PYLINT_BIN) - @$(PYLINT_BIN) compiler/*.py $(addprefix tools/,benchcmp coverparse diffrange grumpc grumprun) +pylint: $(PYLINT_BIN) $(COMPILER_SRCS) $(PYTHONPARSER_SRCS) $(COMPILER_BIN) $(RUNNER_BIN) $(TOOL_BINS) + @$(PYTHON) $(PYLINT_BIN) $(COMPILER_SRCS) $(COMPILER_BIN) $(RUNNER_BIN) $(TOOL_BINS) lint: golint pylint +# ------------------------------------------------------------------------------ +# Native modules +# ------------------------------------------------------------------------------ + +$(PKG_DIR)/__python__/__go__/%.a: build/src/__python__/__go__/%/module.go $(RUNTIME) + @mkdir -p $(@D) + @go install __python__/__go__/$* + +build/src/__python__/__go__/%/module.go: $(PKGC_BIN) $(RUNTIME) + @mkdir -p $(@D) + @$(PKGC_BIN) $* > $@ + +$(PKG_DIR)/__python__/__go__/grumpy.a: $(RUNTIME) + +.PRECIOUS: build/src/__python__/__go__/%/module.go $(PKG_DIR)/__python__/__go__/%.a + # ------------------------------------------------------------------------------ # Standard library # ------------------------------------------------------------------------------ -define GRUMPY_STDLIB -build/src/grumpy/lib/$(2)/module.go: $(1) $(COMPILER) - @mkdir -p build/src/grumpy/lib/$(2) - @$(COMPILER_BIN) -modname=$(notdir $(2)) $(1) > $$@ +$(LIB_SRCS): $(GOPATH_PY_ROOT)/%: lib/% + @mkdir -p $(@D) + @cp -f $< $@ -build/src/grumpy/lib/$(2)/module.d: $(1) - @mkdir -p build/src/grumpy/lib/$(2) - @$(PYTHON) -m modulefinder -p $(ROOT_DIR)/lib:$(ROOT_DIR)/third_party/stdlib:$(ROOT_DIR)/third_party/pypy $$< | awk '{if (($$$$1 == "m" || $$$$1 == "P") && $$$$2 != "__main__" && $$$$2 != "$(2)") {gsub(/\./, "/", $$$$2); print "$(PKG_DIR)/grumpy/lib/$(2).a: $(PKG_DIR)/grumpy/lib/" $$$$2 ".a"}}' > $$@ +$(THIRD_PARTY_STDLIB_SRCS): $(GOPATH_PY_ROOT)/%: third_party/stdlib/% + @mkdir -p $(@D) + @cp -f $< $@ -$(PKG_DIR)/grumpy/lib/$(2).a: build/src/grumpy/lib/$(2)/module.go $(RUNTIME) - @mkdir -p $(PKG_DIR)/grumpy/lib/$(dir $(2)) - @go tool compile -o $$@ -p grumpy/lib/$(2) -complete -I $(PKG_DIR) -pack $$< +$(THIRD_PARTY_PYPY_SRCS): $(GOPATH_PY_ROOT)/%: third_party/pypy/% + @mkdir -p $(@D) + @cp -f $< $@ --include build/src/grumpy/lib/$(2)/module.d -endef +$(THIRD_PARTY_OUROBOROS_SRCS): $(GOPATH_PY_ROOT)/%: third_party/ouroboros/% + @mkdir -p $(@D) + @cp -f $< $@ -$(eval $(foreach x,$(shell seq $(words $(STDLIB_SRCS))),$(call GRUMPY_STDLIB,$(word $(x),$(STDLIB_SRCS)),$(word $(x),$(STDLIB_PACKAGES))))) +build/stdlib.mk: build/bin/genmake | $(STDLIB_SRCS) + @genmake build > $@ + +-include build/stdlib.mk + +$(patsubst %,build/src/__python__/%/module.go,$(STDLIB_PACKAGES)): $(COMPILER) +$(patsubst 
%,build/src/__python__/%/module.d,$(STDLIB_PACKAGES)): build/bin/pydeps $(PYTHONPARSER_SRCS) $(COMPILER) +$(patsubst %,$(PKG_DIR)/__python__/%.a,$(STDLIB_PACKAGES)): $(RUNTIME) define GRUMPY_STDLIB_TEST -build/testing/$(patsubst %_test,%_test_,$(notdir $(1))).go: - @mkdir -p build/testing - @echo 'package main' > $$@ - @echo 'import (' >> $$@ - @echo ' "os"' >> $$@ - @echo ' "grumpy"' >> $$@ - @echo ' mod "grumpy/lib/$(1)"' >> $$@ - @echo ')' >> $$@ - @echo 'func main() {' >> $$@ - @echo ' os.Exit(grumpy.RunMain(mod.Code))' >> $$@ - @echo '}' >> $$@ - -build/testing/$(notdir $(1)): build/testing/$(patsubst %_test,%_test_,$(notdir $(1))).go $(RUNTIME) $(PKG_DIR)/grumpy/lib/$(1).a - @go build -o $$@ $$< - -build/testing/$(notdir $(1)).pass: build/testing/$(notdir $(1)) - @$$< +build/testing/$(notdir $(1)).pass: $(RUNTIME) $(PKG_DIR)/__python__/$(1).a $(RUNNER_BIN) $(PKG_DIR)/__python__/traceback.a + @mkdir -p $$(@D) + @$(RUNNER_BIN) -m $(subst /,.,$(1)) @touch $$@ @echo 'lib/$(1) PASS' @@ -257,21 +315,26 @@ $(eval $(foreach x,$(STDLIB_TESTS),$(call GRUMPY_STDLIB_TEST,$(x)))) $(PY_DIR)/weetest.py: lib/weetest.py @cp -f $< $@ -$(patsubst %_test,build/%.go,$(ACCEPT_TESTS)): build/%.go: %_test.py $(COMPILER) +$(PYTHONPARSER_SRCS): $(PY_DIR)/grumpy/%: third_party/% @mkdir -p $(@D) - @$(COMPILER_BIN) $< > $@ + @cp -f $< $@ -# TODO: These should not depend on stdlibs and should instead build a .d file. -$(patsubst %,build/%,$(ACCEPT_TESTS)): build/%_test: build/%.go $(RUNTIME) $(STDLIB) +$(ACCEPT_PASS_FILES): build/%_test.pass: %_test.py $(RUNTIME) $(STDLIB) $(RUNNER_BIN) @mkdir -p $(@D) - @go build -o $@ $< - -$(ACCEPT_PASS_FILES): build/%_test.pass: build/%_test - @$< + @$(RUNNER_BIN) < $< @touch $@ @echo '$*_test PASS' +NATIVE_TEST_DEPS := \ + $(PKG_DIR)/__python__/__go__/encoding/csv.a \ + $(PKG_DIR)/__python__/__go__/image.a \ + $(PKG_DIR)/__python__/__go__/math.a \ + $(PKG_DIR)/__python__/__go__/strings.a + +build/testing/native_test.pass: $(NATIVE_TEST_DEPS) + $(ACCEPT_PY_PASS_FILES): build/%_py.pass: %.py $(PY_DIR)/weetest.py + @mkdir -p $(@D) @$(PYTHON) $< @touch $@ @echo '$*_py PASS' @@ -290,11 +353,12 @@ $(BENCHMARK_BINS): build/benchmarks/%_benchmark: build/benchmarks/%.go $(RUNTIME install: $(RUNNER_BIN) $(COMPILER) $(RUNTIME) $(STDLIB) # Binary executables - install -Dm755 build/bin/grumpc "$(DESTDIR)/usr/bin/grumpc" - install -Dm755 build/bin/grumprun "$(DESTDIR)/usr/bin/grumprun" + install -d "$(DESTDIR)/usr/bin" + install -m755 build/bin/grumpc "$(DESTDIR)/usr/bin/grumpc" + install -m755 build/bin/grumprun "$(DESTDIR)/usr/bin/grumprun" # Python module install -d "$(DESTDIR)"{/usr/lib/go,"$(PY_INSTALL_DIR)"} - cp -rv --no-preserve=ownership "$(PY_DIR)/grumpy" "$(DESTDIR)$(PY_INSTALL_DIR)" + cp -rv "$(PY_DIR)/grumpy" "$(DESTDIR)$(PY_INSTALL_DIR)" # Go package and sources - cp -rv --no-preserve=ownership build/pkg build/src "$(DESTDIR)/usr/lib/go/" - + install -d "$(DESTDIR)/usr/lib/go/" + cp -rv build/pkg build/src "$(DESTDIR)/usr/lib/go/" diff --git a/README.md b/README.md index 23735166..13ec575d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Grumpy: Go running Python [![Build Status](https://travis-ci.org/google/grumpy.svg?branch=master)](https://travis-ci.org/google/grumpy) +[![Join the chat at https://gitter.im/grumpy-devel/Lobby](https://badges.gitter.im/grumpy-devel/Lobby.svg)](https://gitter.im/grumpy-devel/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) ## Overview @@ -9,8 +10,8 @@ to be a near drop-in replacement for CPython 2.7. 
The key difference is that it compiles Python source code to Go source code which is then compiled to native code, rather than to bytecode. This means that Grumpy has no VM. The compiled Go source code is a series of calls to the Grumpy runtime, a Go library serving a -similar purpose to the Python C API (although the C API is not directly -supported). +similar purpose to the Python C API (although the API is incompatible with +CPython's). ## Limitations @@ -33,29 +34,29 @@ supported). There are three basic categories of incomplete functionality: -1. Language features: Most language features are implemented with the notable - exception of decorators. There are also a handful of operators that aren't - yet supported. +1. [Language features](https://github.com/google/grumpy/wiki/Missing-features#language-features): + Most language features are implemented with the notable exception of + [old-style classes](http://stackoverflow.com/questions/54867/what-is-the-difference-between-old-style-and-new-style-classes-in-python). + There are also a handful of operators that aren't yet supported. -2. Builtin functions and types: There are a number of missing functions and - types in `__builtins__` that have not yet been implemented. There are also a - lot of methods on builtin types that are missing. +2. [Builtin functions and types](https://github.com/google/grumpy/wiki/Missing-features#builtins): + There are a number of missing functions and types in `__builtins__` that have + not yet been implemented. There are also a lot of methods on builtin types + that are missing. -3. Standard library: The Python standard library is very large and much of it - is pure Python, so as the language features and builtins get filled out, many - modules will just work. But there are also a number of libraries in CPython - that are C extension modules which will need to be rewritten. +3. [Standard library](https://github.com/google/grumpy/wiki/Missing-features#standard-libraries): + The Python standard library is very large and much of it is pure Python, so + as the language features and builtins get filled out, many modules will + just work. But there are also a number of libraries in CPython that are C + extension modules which will need to be rewritten. 4. C locale support: Go doesn't support locales in the same way that C does. As such, some functionality that is locale-dependent may not currently work the same as in CPython. - To see the status of a particular feature or standard library module, click - [here](https://github.com/google/grumpy/wiki/Missing-Features). - ## Running Grumpy Programs -### Method 1: grumprun: +### Method 1: make run: The simplest way to execute a Grumpy program is to use `make run`, which wraps a shell script called grumprun that takes Python code on stdin and builds and runs @@ -66,33 +67,62 @@ root directory of the Grumpy source code distribution: echo "print 'hello, world'" | make run ``` -### Method 2: grumpc: +### Method 2: grumpc and grumprun: For more complicated programs, you'll want to compile your Python source code to Go using grumpc (the Grumpy compiler) and then build the Go code using `go -build`. First, write a simple .py script: +build`. Since Grumpy programs are statically linked, all the modules in a +program must be findable by the Grumpy toolchain on the GOPATH. Grumpy looks for +Go packages corresponding to Python modules in the \_\_python\_\_ subdirectory +of the GOPATH. 
By convention, this subdirectory is also used for staging Python +source code, making it similar to the PYTHONPATH. -``` -echo 'print "hello, world"' > hello.py -``` - -Next, build the toolchain and export some environment variables that make the -toolchain work: +The first step is to set up the shell so that the Grumpy toolchain and libraries +can be found. From the root directory of the Grumpy source distribution run: ``` make +export PATH=$PWD/build/bin:$PATH export GOPATH=$PWD/build export PYTHONPATH=$PWD/build/lib/python2.7/site-packages ``` -Finally, compile the Python script and build a binary from it: +You will know things are working if you see the expected output from this +command: ``` -build/bin/grumpc hello.py > hello.go -go build -o hello hello.go +echo 'import sys; print sys.version' | grumprun ``` -Now execute the `./hello` binary to your heart's content. +Next, we will write our simple Python module into the \_\_python\_\_ directory: + +``` +echo 'def hello(): print "hello, world"' > $GOPATH/src/__python__/hello.py +``` + +To build a Go package from our Python script, run the following: + +``` +mkdir -p $GOPATH/src/__python__/hello +grumpc -modname=hello $GOPATH/src/__python__/hello.py > \ + $GOPATH/src/__python__/hello/module.go +``` + +You should now be able to build a Go program that imports the package +"\_\_python\_\_/hello". We can also import this module into Python programs +that are built using grumprun: + +``` +echo 'from hello import hello; hello()' | grumprun +``` + +grumprun is doing a few things under the hood here: + +1. Compiles the given Python code to a dummy Go package, the same way we + produced \_\_python\_\_/hello/module.go above +2. Produces a main Go package that imports the Go package from step 1. and + executes it as our \_\_main\_\_ Python package +3. Executes `go run` on the main package generated in step 2. ## Developing Grumpy @@ -103,7 +133,8 @@ writing, you may need to change one or more of these. Grumpy converts Python programs into Go programs and `grumpc` is the tool responsible for parsing Python code and generating Go code from it. `grumpc` is -written in Python and uses the `ast` module to accomplish parsing. +written in Python and uses the [`pythonparser`](https://github.com/m-labs/pythonparser) +module to accomplish parsing. The grumpc script itself lives at `tools/grumpc`. It is supported by a number of Python modules in the `compiler` subdir. @@ -122,7 +153,7 @@ counterparts in CPython. Much of the Python standard library is written in Python and thus "just works" in Grumpy. These parts of the standard library are copied from CPython 2.7 (possibly with light modifications). For licensing reasons, these files are kept -in the `third_party/stdlib` subdir. +in the `third_party` subdir. The parts of the standard library that cannot be written in pure Python, e.g. file and directory operations, are kept in the `lib` subdir. In CPython these @@ -135,9 +166,13 @@ available in Python. - `compiler`: Python package implementating Python -> Go transcompilation logic. - `lib`: Grumpy-specific Python standard library implementation. - `runtime`: Go source code for the Grumpy runtime library. +- `third_party/ouroboros`: Pure Python standard libraries copied from the + [Ouroboros project](https://github.com/pybee/ouroboros). +- `third_party/pypy`: Pure Python standard libraries copied from PyPy. - `third_party/stdlib`: Pure Python standard libraries copied from CPython. - `tools`: Transcompilation and utility binaries. ## Contact -Questions? 
Comments? Drop us a line at [grumpy-users@googlegroups.com](https://groups.google.com/forum/#!forum/grumpy-users). +Questions? Comments? Drop us a line at [grumpy-users@googlegroups.com](https://groups.google.com/forum/#!forum/grumpy-users) +or join our [Gitter channel](https://gitter.im/grumpy-devel/Lobby) diff --git a/benchmarks/dict.py b/benchmarks/dict.py index 1610db2c..ae08b8b3 100644 --- a/benchmarks/dict.py +++ b/benchmarks/dict.py @@ -19,6 +19,16 @@ import weetest +def BenchmarkDictCreate(b): + for _ in xrange(b.N): + d = {'one': 1, 'two': 2, 'three': 3} + + +def BenchmarkDictCreateFunc(b): + for _ in xrange(b.N): + d = dict(one=1, two=2, three=3) + + def BenchmarkDictGetItem(b): d = {42: 123} for _ in xrange(b.N): diff --git a/compiler/block.py b/compiler/block.py index 21a1b5a8..423ded2d 100644 --- a/compiler/block.py +++ b/compiler/block.py @@ -16,13 +16,17 @@ """Classes for analyzing and storing the state of Python code blocks.""" +from __future__ import unicode_literals + import abc -import ast import collections import re from grumpy.compiler import expr from grumpy.compiler import util +from grumpy.pythonparser import algorithm +from grumpy.pythonparser import ast +from grumpy.pythonparser import source _non_word_re = re.compile('[^A-Za-z0-9_]') @@ -41,9 +45,8 @@ def __init__(self, name, alias=None): class Loop(object): """Represents a for or while loop within a particular block.""" - def __init__(self, start_label, end_label): - self.start_label = start_label - self.end_label = end_label + def __init__(self, breakvar): + self.breakvar = breakvar class Block(object): @@ -51,19 +54,9 @@ class Block(object): __metaclass__ = abc.ABCMeta - # These are ModuleBlock attributes. Avoid pylint errors for accessing them on - # Block objects by defining them here. - _filename = None - _full_package_name = None - _libroot = None - _lines = None - _runtime = None - _strings = None - imports = None - _future_features = None - - def __init__(self, parent_block, name): - self.parent_block = parent_block + def __init__(self, parent, name): + self.root = parent.root if parent else self + self.parent = parent self.name = name self.free_temps = set() self.used_temps = set() @@ -73,41 +66,6 @@ def __init__(self, parent_block, name): self.loop_stack = [] self.is_generator = False - block = self - while block and not isinstance(block, ModuleBlock): - block = block.parent_block - self._module_block = block - - @property - def full_package_name(self): - # pylint: disable=protected-access - return self._module_block._full_package_name - - @property - def runtime(self): - return self._module_block._runtime # pylint: disable=protected-access - - @property - def libroot(self): - return self._module_block._libroot # pylint: disable=protected-access - - @property - def filename(self): - return self._module_block._filename # pylint: disable=protected-access - - @property - def lines(self): - return self._module_block._lines # pylint: disable=protected-access - - @property - def strings(self): - return self._module_block._strings # pylint: disable=protected-access - - @property - def future_features(self): - # pylint: disable=protected-access - return self._module_block._future_features - @abc.abstractmethod def bind_var(self, writer, name, value): """Writes Go statements for assigning value to named var in this block. @@ -141,30 +99,6 @@ def resolve_name(self, writer, name): """ pass - def add_import(self, name): - """Register the named Go package for import in this block's ModuleBlock. 
- - add_import walks the block chain to the root ModuleBlock and adds a Package - to its imports dict. - - Args: - name: The fully qualified Go package name. - Returns: - A Package representing the import. - """ - return self.add_native_import('/'.join([self.libroot, name])) - - def add_native_import(self, name): - alias = None - if name == 'grumpy': - name = self.runtime - alias = 'πg' - if name in self._module_block.imports: - return self._module_block.imports[name] - package = Package(name, alias) - self._module_block.imports[name] = package - return package - def genlabel(self, is_checkpoint=False): self.label_count += 1 if is_checkpoint: @@ -189,8 +123,8 @@ def free_temp(self, v): self.used_temps.remove(v) self.free_temps.add(v) - def push_loop(self): - loop = Loop(self.genlabel(), self.genlabel()) + def push_loop(self, breakvar): + loop = Loop(breakvar) self.loop_stack.append(loop) return loop @@ -200,37 +134,25 @@ def pop_loop(self): def top_loop(self): return self.loop_stack[-1] - def intern(self, s): - if len(s) > 64 or _non_word_re.search(s): - return 'πg.NewStr({})'.format(util.go_str(s)) - self.strings.add(s) - return 'ß' + s - def _resolve_global(self, writer, name): result = self.alloc_temp() writer.write_checked_call2( - result, 'πg.ResolveGlobal(πF, {})', self.intern(name)) + result, 'πg.ResolveGlobal(πF, {})', self.root.intern(name)) return result class ModuleBlock(Block): - """Python block for a module. - - Attributes: - imports: A dict mapping fully qualified Go package names to Package objects. - """ - - def __init__(self, full_package_name, runtime, libroot, filename, lines, - future_features): - super(ModuleBlock, self).__init__(None, '') - self._full_package_name = full_package_name - self._runtime = runtime - self._libroot = libroot - self._filename = filename - self._lines = lines - self._strings = set() - self.imports = {} - self._future_features = future_features + """Python block for a module.""" + + def __init__(self, importer, full_package_name, + filename, src, future_features): + Block.__init__(self, None, '') + self.importer = importer + self.full_package_name = full_package_name + self.filename = filename + self.buffer = source.Buffer(src) + self.strings = set() + self.future_features = future_features def bind_var(self, writer, name, value): writer.write_checked_call1( @@ -244,24 +166,31 @@ def del_var(self, writer, name): def resolve_name(self, writer, name): return self._resolve_global(writer, name) + def intern(self, s): + if len(s) > 64 or _non_word_re.search(s): + return 'πg.NewStr({})'.format(util.go_str(s)) + self.strings.add(s) + return 'ß' + s + class ClassBlock(Block): """Python block for a class definition.""" - def __init__(self, parent_block, name, global_vars): - super(ClassBlock, self).__init__(parent_block, name) + def __init__(self, parent, name, global_vars): + Block.__init__(self, parent, name) self.global_vars = global_vars def bind_var(self, writer, name, value): if name in self.global_vars: - return self._module_block.bind_var(writer, name, value) + return self.root.bind_var(writer, name, value) writer.write_checked_call1('πClass.SetItem(πF, {}.ToObject(), {})', - self.intern(name), value) + self.root.intern(name), value) def del_var(self, writer, name): if name in self.global_vars: - return self._module_block.del_var(writer, name) - writer.write_checked_call1('πg.DelVar(πF, πClass, {})', self.intern(name)) + return self.root.del_var(writer, name) + writer.write_checked_call1('πg.DelVar(πF, πClass, {})', + self.root.intern(name)) def 
resolve_name(self, writer, name): local = 'nil' @@ -269,7 +198,7 @@ def resolve_name(self, writer, name): # Only look for a local in an outer block when name hasn't been declared # global in this block. If it has been declared global then we fallback # straight to the global dict. - block = self.parent_block + block = self.parent while not isinstance(block, ModuleBlock): if isinstance(block, FunctionBlock) and name in block.vars: var = block.vars[name] @@ -277,26 +206,26 @@ def resolve_name(self, writer, name): local = util.adjust_local_name(name) # When it is declared global, prefer it to anything in outer blocks. break - block = block.parent_block + block = block.parent result = self.alloc_temp() writer.write_checked_call2( result, 'πg.ResolveClass(πF, πClass, {}, {})', - local, self.intern(name)) + local, self.root.intern(name)) return result class FunctionBlock(Block): """Python block for a function definition.""" - def __init__(self, parent_block, name, block_vars, is_generator): - super(FunctionBlock, self).__init__(parent_block, name) + def __init__(self, parent, name, block_vars, is_generator): + Block.__init__(self, parent, name) self.vars = block_vars - self.parent_block = parent_block + self.parent = parent self.is_generator = is_generator def bind_var(self, writer, name, value): if self.vars[name].type == Var.TYPE_GLOBAL: - return self._module_block.bind_var(writer, name, value) + return self.root.bind_var(writer, name, value) writer.write('{} = {}'.format(util.adjust_local_name(name), value)) def del_var(self, writer, name): @@ -305,7 +234,7 @@ def del_var(self, writer, name): raise util.ParseError( None, 'cannot delete nonexistent local: {}'.format(name)) if var.type == Var.TYPE_GLOBAL: - return self._module_block.del_var(writer, name) + return self.root.del_var(writer, name) adjusted_name = util.adjust_local_name(name) # Resolve local first to ensure the variable is already bound. 
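# (CheckLocal is expected to raise here, mirroring CPython, where `del x` on a
# local that has never been assigned fails rather than silently succeeding.)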
writer.write_checked_call1('πg.CheckLocal(πF, {}, {})', @@ -324,7 +253,7 @@ def resolve_name(self, writer, name): util.adjust_local_name(name), util.go_str(name)) return expr.GeneratedLocalVar(name) - block = block.parent_block + block = block.parent return self._resolve_global(writer, name) @@ -349,7 +278,7 @@ def __init__(self, name, var_type, arg_index=None): self.init_expr = None -class BlockVisitor(ast.NodeVisitor): +class BlockVisitor(algorithm.Visitor): """Visits nodes in a function or class to determine block variables.""" # pylint: disable=invalid-name,missing-docstring @@ -397,8 +326,9 @@ def visit_ImportFrom(self, node): self._register_local(alias.asname or alias.name) def visit_With(self, node): - if node.optional_vars: - self._assign_target(node.optional_vars) + for item in node.items: + if item.optional_vars: + self._assign_target(item.optional_vars) self.generic_visit(node) def _assign_target(self, target): @@ -431,14 +361,14 @@ class FunctionBlockVisitor(BlockVisitor): # pylint: disable=invalid-name,missing-docstring def __init__(self, node): - super(FunctionBlockVisitor, self).__init__() + BlockVisitor.__init__(self) self.is_generator = False node_args = node.args - args = [a.id for a in node_args.args] + args = [a.arg for a in node_args.args] if node_args.vararg: - args.append(node_args.vararg) + args.append(node_args.vararg.arg) if node_args.kwarg: - args.append(node_args.kwarg) + args.append(node_args.kwarg.arg) for i, name in enumerate(args): if name in self.vars: msg = "duplicate argument '{}' in function definition".format(name) diff --git a/compiler/block_test.py b/compiler/block_test.py index ec73d04b..63376997 100644 --- a/compiler/block_test.py +++ b/compiler/block_test.py @@ -16,13 +16,15 @@ """Tests Package, Block, BlockVisitor and related classes.""" -import ast +from __future__ import unicode_literals + import textwrap import unittest from grumpy.compiler import block -from grumpy.compiler import stmt +from grumpy.compiler import imputil from grumpy.compiler import util +from grumpy import pythonparser class PackageTest(unittest.TestCase): @@ -39,30 +41,11 @@ def testCreateGrump(self): class BlockTest(unittest.TestCase): - def testAddImport(self): - module_block = _MakeModuleBlock() - func1_block = block.FunctionBlock(module_block, 'func1', {}, False) - func2_block = block.FunctionBlock(func1_block, 'func2', {}, False) - package = func2_block.add_import('foo/bar') - self.assertEqual(package.name, 'grumpy/lib/foo/bar') - self.assertEqual(package.alias, 'π_grumpyΓlibΓfooΓbar') - self.assertEqual(module_block.imports, {'grumpy/lib/foo/bar': package}) - - def testAddImportRepeated(self): - b = _MakeModuleBlock() - package = b.add_import('foo') - self.assertEqual(package.name, 'grumpy/lib/foo') - self.assertEqual(package.alias, 'π_grumpyΓlibΓfoo') - self.assertEqual(b.imports, {'grumpy/lib/foo': package}) - package2 = b.add_import('foo') - self.assertIs(package, package2) - self.assertEqual(b.imports, {'grumpy/lib/foo': package}) - def testLoop(self): b = _MakeModuleBlock() - loop = b.push_loop() + loop = b.push_loop(None) self.assertEqual(loop, b.top_loop()) - inner_loop = b.push_loop() + inner_loop = b.push_loop(None) self.assertEqual(inner_loop, b.top_loop()) b.pop_loop() self.assertEqual(loop, b.top_loop()) @@ -106,7 +89,7 @@ def testResolveName(self): def _ResolveName(self, b, name): writer = util.Writer() b.resolve_name(writer, name) - return writer.out.getvalue() + return writer.getvalue() class BlockVisitorTest(unittest.TestCase): @@ -204,7 +187,7 @@ def 
testGlobalIsParam(self): visitor.visit, _ParseStmt('global foo')) def testGlobalUsedPriorToDeclaration(self): - node = ast.parse('foo = 42\nglobal foo') + node = pythonparser.parse('foo = 42\nglobal foo') visitor = block.BlockVisitor() self.assertRaisesRegexp(util.ParseError, 'used prior to global declaration', visitor.generic_visit, node) @@ -243,12 +226,13 @@ def testYieldExpr(self): def _MakeModuleBlock(): - return block.ModuleBlock('__main__', 'grumpy', 'grumpy/lib', '', [], - stmt.FutureFeatures()) + importer = imputil.Importer(None, '__main__', '/tmp/foo.py', False) + return block.ModuleBlock(importer, '__main__', '', '', + imputil.FutureFeatures()) def _ParseStmt(stmt_str): - return ast.parse(stmt_str).body[0] + return pythonparser.parse(stmt_str).body[0] if __name__ == '__main__': diff --git a/compiler/expr.py b/compiler/expr.py index c34e8454..bdc72966 100644 --- a/compiler/expr.py +++ b/compiler/expr.py @@ -16,6 +16,8 @@ """Classes representing generated expressions.""" +from __future__ import unicode_literals + import abc from grumpy.compiler import util @@ -79,3 +81,15 @@ def expr(self): nil_expr = GeneratedLiteral('nil') + + +class BlankVar(GeneratedExpr): + def __init__(self): + self.name = '_' + + @property + def expr(self): + return '_' + + +blank_var = BlankVar() diff --git a/compiler/expr_visitor.py b/compiler/expr_visitor.py index a544b8ec..267abdf7 100644 --- a/compiler/expr_visitor.py +++ b/compiler/expr_visitor.py @@ -16,22 +16,26 @@ """Visitor class for traversing Python expressions.""" -import ast +from __future__ import unicode_literals + +import contextlib import textwrap -from grumpy.compiler import block from grumpy.compiler import expr from grumpy.compiler import util +from grumpy.pythonparser import algorithm +from grumpy.pythonparser import ast -class ExprVisitor(ast.NodeVisitor): +class ExprVisitor(algorithm.Visitor): """Builds and returns a Go expression representing the Python nodes.""" # pylint: disable=invalid-name,missing-docstring - def __init__(self, block_, writer): - self.block = block_ - self.writer = writer + def __init__(self, stmt_visitor): + self.stmt_visitor = stmt_visitor + self.block = stmt_visitor.block + self.writer = stmt_visitor.writer def generic_visit(self, node): msg = 'expression node not yet implemented: ' + type(node).__name__ @@ -42,7 +46,7 @@ def visit_Attribute(self, node): attr = self.block.alloc_temp() self.writer.write_checked_call2( attr, 'πg.GetAttr(πF, {}, {}, nil)', - obj.expr, self.block.intern(node.attr)) + obj.expr, self.block.root.intern(node.attr)) return attr def visit_BinOp(self, node): @@ -178,37 +182,55 @@ def visit_Dict(self, node): self.writer.write('{} = {}.ToObject()'.format(result.name, d.expr)) return result + def visit_Set(self, node): + with self.block.alloc_temp('*πg.Set') as s: + self.writer.write('{} = πg.NewSet()'.format(s.name)) + for e in node.elts: + with self.visit(e) as value: + self.writer.write_checked_call2(expr.blank_var, '{}.Add(πF, {})', + s.expr, value.expr) + result = self.block.alloc_temp() + self.writer.write('{} = {}.ToObject()'.format(result.name, s.expr)) + return result + def visit_DictComp(self, node): result = self.block.alloc_temp() - elt = ast.Tuple(elts=[node.key, node.value], context=ast.Load) - with self.visit(ast.GeneratorExp(elt, node.generators)) as gen: + elt = ast.Tuple(elts=[node.key, node.value]) + gen_node = ast.GeneratorExp( + elt=elt, generators=node.generators, loc=node.loc) + with self.visit(gen_node) as gen: self.writer.write_checked_call2( result, 
'πg.DictType.Call(πF, πg.Args{{{}}}, nil)', gen.expr) return result def visit_ExtSlice(self, node): result = self.block.alloc_temp() - with self.block.alloc_temp('[]*πg.Object') as dims: - self.writer.write('{} = make([]*πg.Object, {})'.format( - dims.name, len(node.dims))) - for i, dim in enumerate(node.dims): - with self.visit(dim) as s: - self.writer.write('{}[{}] = {}'.format(dims.name, i, s.expr)) - self.writer.write('{} = πg.NewTuple({}...).ToObject()'.format( - result.name, dims.expr)) + if len(node.dims) <= util.MAX_DIRECT_TUPLE: + with contextlib.nested(*(self.visit(d) for d in node.dims)) as dims: + self.writer.write('{} = πg.NewTuple{}({}).ToObject()'.format( + result.name, len(dims), ', '.join(d.expr for d in dims))) + else: + with self.block.alloc_temp('[]*πg.Object') as dims: + self.writer.write('{} = make([]*πg.Object, {})'.format( + dims.name, len(node.dims))) + for i, dim in enumerate(node.dims): + with self.visit(dim) as s: + self.writer.write('{}[{}] = {}'.format(dims.name, i, s.expr)) + self.writer.write('{} = πg.NewTuple({}...).ToObject()'.format( + result.name, dims.expr)) return result def visit_GeneratorExp(self, node): - body = ast.Expr(value=ast.Yield(node.elt), lineno=None) + body = ast.Expr(value=ast.Yield(value=node.elt), loc=node.loc) for comp_node in reversed(node.generators): for if_node in reversed(comp_node.ifs): - body = ast.If(test=if_node, body=[body], orelse=[], lineno=None) # pylint: disable=redefined-variable-type + body = ast.If(test=if_node, body=[body], orelse=[], loc=node.loc) # pylint: disable=redefined-variable-type body = ast.For(target=comp_node.target, iter=comp_node.iter, - body=[body], orelse=[], lineno=None) + body=[body], orelse=[], loc=node.loc) args = ast.arguments(args=[], vararg=None, kwarg=None, defaults=[]) node = ast.FunctionDef(name='', args=args, body=[body]) - gen_func = self.visit_function_inline(node) + gen_func = self.stmt_visitor.visit_function_inline(node) result = self.block.alloc_temp() self.writer.write_checked_call2( result, '{}.Call(πF, nil, nil)', gen_func.expr) @@ -240,10 +262,10 @@ def visit_Index(self, node): return result def visit_Lambda(self, node): - ret = ast.Return(node.body, lineno=node.lineno) + ret = ast.Return(value=node.body, loc=node.loc) func_node = ast.FunctionDef( name='', args=node.args, body=[ret]) - return self.visit_function_inline(func_node) + return self.stmt_visitor.visit_function_inline(func_node) def visit_List(self, node): with self._visit_seq_elts(node.elts) as elems: @@ -254,7 +276,9 @@ def visit_List(self, node): def visit_ListComp(self, node): result = self.block.alloc_temp() - with self.visit(ast.GeneratorExp(node.elt, node.generators)) as gen: + gen_node = ast.GeneratorExp( + elt=node.elt, generators=node.generators, loc=node.loc) + with self.visit(gen_node) as gen: self.writer.write_checked_call2( result, 'πg.ListType.Call(πF, πg.Args{{{}}}, nil)', gen.expr) return result @@ -276,6 +300,8 @@ def visit_Num(self, node): expr_str = expr_str + '.Neg()' elif isinstance(node.n, float): expr_str = 'NewFloat({})'.format(node.n) + elif isinstance(node.n, complex): + expr_str = 'NewComplex(complex({}, {}))'.format(node.n.real, node.n.imag) else: msg = 'number type not yet implemented: ' + type(node.n).__name__ raise util.ParseError(node, msg) @@ -309,14 +335,19 @@ def visit_Str(self, node): expr_str = 'πg.NewUnicode({}).ToObject()'.format( util.go_str(node.s.encode('utf-8'))) else: - expr_str = '{}.ToObject()'.format(self.block.intern(node.s)) + expr_str = 
'{}.ToObject()'.format(self.block.root.intern(node.s)) return expr.GeneratedLiteral(expr_str) def visit_Tuple(self, node): - with self._visit_seq_elts(node.elts) as elems: - result = self.block.alloc_temp() - self.writer.write('{} = πg.NewTuple({}...).ToObject()'.format( - result.expr, elems.expr)) + result = self.block.alloc_temp() + if len(node.elts) <= util.MAX_DIRECT_TUPLE: + with contextlib.nested(*(self.visit(e) for e in node.elts)) as elts: + self.writer.write('{} = πg.NewTuple{}({}).ToObject()'.format( + result.name, len(elts), ', '.join(e.expr for e in elts))) + else: + with self._visit_seq_elts(node.elts) as elems: + self.writer.write('{} = πg.NewTuple({}...).ToObject()'.format( + result.expr, elems.expr)) return result def visit_UnaryOp(self, node): @@ -358,7 +389,7 @@ def visit_Yield(self, node): ast.Add: 'πg.Add(πF, {lhs}, {rhs})', ast.Div: 'πg.Div(πF, {lhs}, {rhs})', # TODO: Support "from __future__ import division". - ast.FloorDiv: 'πg.Div(πF, {lhs}, {rhs})', + ast.FloorDiv: 'πg.FloorDiv(πF, {lhs}, {rhs})', ast.LShift: 'πg.LShift(πF, {lhs}, {rhs})', ast.Mod: 'πg.Mod(πF, {lhs}, {rhs})', ast.Mult: 'πg.Mul(πF, {lhs}, {rhs})', @@ -378,73 +409,10 @@ def visit_Yield(self, node): _UNARY_OP_TEMPLATES = { ast.Invert: 'πg.Invert(πF, {operand})', + ast.UAdd: 'πg.Pos(πF, {operand})', ast.USub: 'πg.Neg(πF, {operand})', } - def visit_function_inline(self, node): - """Returns an GeneratedExpr for a function with the given body.""" - # First pass collects the names of locals used in this function. Do this in - # a separate pass so that we know whether to resolve a name as a local or a - # global during the second pass. - func_visitor = block.FunctionBlockVisitor(node) - for child in node.body: - func_visitor.visit(child) - func_block = block.FunctionBlock(self.block, node.name, func_visitor.vars, - func_visitor.is_generator) - # TODO: Find a better way to reduce coupling between ExprVisitor and - # StatementVisitor. - from grumpy.compiler import stmt # pylint: disable=g-import-not-at-top - visitor = stmt.StatementVisitor(func_block) - # Indent so that the function body is aligned with the goto labels. - with visitor.writer.indent_block(): - visitor._visit_each(node.body) # pylint: disable=protected-access - - result = self.block.alloc_temp() - with self.block.alloc_temp('[]πg.Param') as func_args: - args = node.args - argc = len(args.args) - self.writer.write('{} = make([]πg.Param, {})'.format( - func_args.expr, argc)) - # The list of defaults only contains args for which a default value is - # specified so pad it with None to make it the same length as args. - defaults = [None] * (argc - len(args.defaults)) + args.defaults - for i, (a, d) in enumerate(zip(args.args, defaults)): - with self.visit(d) if d else expr.nil_expr as default: - tmpl = '$args[$i] = πg.Param{Name: $name, Def: $default}' - self.writer.write_tmpl(tmpl, args=func_args.expr, i=i, - name=util.go_str(a.id), default=default.expr) - flags = [] - if args.vararg: - flags.append('πg.CodeFlagVarArg') - if args.kwarg: - flags.append('πg.CodeFlagKWArg') - # The function object gets written to a temporary writer because we need - # it as an expression that we subsequently bind to some variable. 
- self.writer.write_tmpl( - '$result = πg.NewFunction(πg.NewCode($name, $filename, $args, ' - '$flags, func(πF *πg.Frame, πArgs []*πg.Object) ' - '(*πg.Object, *πg.BaseException) {', - result=result.name, name=util.go_str(node.name), - filename=util.go_str(self.block.filename), args=func_args.expr, - flags=' | '.join(flags) if flags else 0) - with self.writer.indent_block(): - for var in func_block.vars.values(): - if var.type != block.Var.TYPE_GLOBAL: - fmt = 'var {0} *πg.Object = {1}; _ = {0}' - self.writer.write(fmt.format( - util.adjust_local_name(var.name), var.init_expr)) - self.writer.write_temp_decls(func_block) - if func_block.is_generator: - self.writer.write('return πg.NewGenerator(πF, func(πSent *πg.Object) ' - '(*πg.Object, *πg.BaseException) {') - with self.writer.indent_block(): - self.writer.write_block(func_block, visitor.writer.out.getvalue()) - self.writer.write('}).ToObject(), nil') - else: - self.writer.write_block(func_block, visitor.writer.out.getvalue()) - self.writer.write('}), πF.Globals()).ToObject()') - return result - def _visit_seq_elts(self, elts): result = self.block.alloc_temp('[]*πg.Object') self.writer.write('{} = make([]*πg.Object, {})'.format( diff --git a/compiler/expr_visitor_test.py b/compiler/expr_visitor_test.py index 301a1031..08e392ea 100644 --- a/compiler/expr_visitor_test.py +++ b/compiler/expr_visitor_test.py @@ -16,16 +16,17 @@ """Tests for ExprVisitor.""" -import ast +from __future__ import unicode_literals + import subprocess import textwrap import unittest from grumpy.compiler import block -from grumpy.compiler import expr_visitor +from grumpy.compiler import imputil from grumpy.compiler import shard_test from grumpy.compiler import stmt -from grumpy.compiler import util +from grumpy import pythonparser def _MakeExprTest(expr): @@ -35,11 +36,13 @@ def Test(self): return Test -def _MakeLiteralTest(lit): +def _MakeLiteralTest(lit, expected=None): + if expected is None: + expected = lit def Test(self): - status, output = _GrumpRun('print repr({!r}),'.format(lit)) + status, output = _GrumpRun('print repr({}),'.format(lit)) self.assertEqual(0, status, output) - self.assertEqual(lit, ast.literal_eval(output)) + self.assertEqual(expected, output.strip()) # pylint: disable=eval-used return Test @@ -129,8 +132,10 @@ def foo(a, b=2): testCompareInTuple = _MakeExprTest('1 in (1, 2, 3)') testCompareNotInTuple = _MakeExprTest('10 < 12 not in (1, 2, 3)') - testDictEmpty = _MakeLiteralTest({}) - testDictNonEmpty = _MakeLiteralTest({'foo': 42, 'bar': 43}) + testDictEmpty = _MakeLiteralTest('{}') + testDictNonEmpty = _MakeLiteralTest("{'foo': 42, 'bar': 43}") + + testSetNonEmpty = _MakeLiteralTest("{'foo', 'bar'}", "set(['foo', 'bar'])") testDictCompFor = _MakeExprTest('{x: str(x) for x in range(3)}') testDictCompForIf = _MakeExprTest( @@ -155,8 +160,8 @@ def foo(a, b=2): testLambda = _MakeExprTest('(lambda *args: args)(1, 2, 3)') testLambda = _MakeExprTest('(lambda **kwargs: kwargs)(x="foo", y="bar")') - testListEmpty = _MakeLiteralTest([]) - testListNonEmpty = _MakeLiteralTest([1, 2]) + testListEmpty = _MakeLiteralTest('[]') + testListNonEmpty = _MakeLiteralTest('[1, 2]') testListCompFor = _MakeExprTest('[int(x) for x in "123"]') testListCompForIf = _MakeExprTest('[x / 3 for x in range(10) if x % 3]') @@ -177,16 +182,18 @@ def foo(): foo()""") self.assertEqual((0, ''), _GrumpRun(code)) - testNumInt = _MakeLiteralTest(42) - testNumLong = _MakeLiteralTest(42L) - testNumIntLarge = _MakeLiteralTest(12345678901234567890) - testNumFloat = 
_MakeLiteralTest(102.1) - testNumFloatNoDecimal = _MakeLiteralTest(5.) - testNumFloatOnlyDecimal = _MakeLiteralTest(.5) - testNumFloatSci = _MakeLiteralTest(1e6) - testNumFloatSciCap = _MakeLiteralTest(1E6) - testNumFloatSciCapPlus = _MakeLiteralTest(1E+6) - testNumFloatSciMinus = _MakeLiteralTest(1e-6) + testNumInt = _MakeLiteralTest('42') + testNumLong = _MakeLiteralTest('42L') + testNumIntLarge = _MakeLiteralTest('12345678901234567890', + '12345678901234567890L') + testNumFloat = _MakeLiteralTest('102.1') + testNumFloatOnlyDecimal = _MakeLiteralTest('.5', '0.5') + testNumFloatNoDecimal = _MakeLiteralTest('5.', '5.0') + testNumFloatSci = _MakeLiteralTest('1e6', '1000000.0') + testNumFloatSciCap = _MakeLiteralTest('1E6', '1000000.0') + testNumFloatSciCapPlus = _MakeLiteralTest('1E+6', '1000000.0') + testNumFloatSciMinus = _MakeLiteralTest('1e-06') + testNumComplex = _MakeLiteralTest('3j') testSubscriptDictStr = _MakeExprTest('{"foo": 42}["foo"]') testSubscriptListInt = _MakeExprTest('[1, 2, 3][2]') @@ -200,37 +207,33 @@ def foo(): testSubscriptMultiDimSlice = _MakeSliceTest( "'foo','bar':'baz':'qux'", "('foo', slice('bar', 'baz', 'qux'))") - testStrEmpty = _MakeLiteralTest('') - testStrAscii = _MakeLiteralTest('abc') - testStrUtf8 = _MakeLiteralTest('\tfoo\n\xcf\x80') - testStrQuoted = _MakeLiteralTest('"foo"') - testStrUtf16 = _MakeLiteralTest(u'\u0432\u043e\u043b\u043d') + testStrEmpty = _MakeLiteralTest("''") + testStrAscii = _MakeLiteralTest("'abc'") + testStrUtf8 = _MakeLiteralTest(r"'\tfoo\n\xcf\x80'") + testStrQuoted = _MakeLiteralTest('\'"foo"\'', '\'"foo"\'') + testStrUtf16 = _MakeLiteralTest("u'\\u0432\\u043e\\u043b\\u043d'") - testTupleEmpty = _MakeLiteralTest(()) - testTupleNonEmpty = _MakeLiteralTest((1, 2, 3)) + testTupleEmpty = _MakeLiteralTest('()') + testTupleNonEmpty = _MakeLiteralTest('(1, 2, 3)') testUnaryOpNot = _MakeExprTest('not True') testUnaryOpInvert = _MakeExprTest('~4') - - def testUnaryOpNotImplemented(self): - self.assertRaisesRegexp(util.ParseError, 'unary op not implemented', - _ParseAndVisitExpr, '+foo') + testUnaryOpPos = _MakeExprTest('+4') def _MakeModuleBlock(): - return block.ModuleBlock('__main__', 'grumpy', 'grumpy/lib', '', [], - stmt.FutureFeatures()) + return block.ModuleBlock(None, '__main__', '', '', + imputil.FutureFeatures()) def _ParseExpr(expr): - return ast.parse(expr).body[0].value + return pythonparser.parse(expr).body[0].value def _ParseAndVisitExpr(expr): - writer = util.Writer() - visitor = expr_visitor.ExprVisitor(_MakeModuleBlock(), writer) - visitor.visit(_ParseExpr(expr)) - return writer.out.getvalue() + visitor = stmt.StatementVisitor(_MakeModuleBlock()) + visitor.visit_expr(_ParseExpr(expr)) + return visitor.writer.getvalue() def _GrumpRun(cmd): diff --git a/compiler/imputil.py b/compiler/imputil.py new file mode 100644 index 00000000..811a6555 --- /dev/null +++ b/compiler/imputil.py @@ -0,0 +1,311 @@ +# coding=utf-8 + +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Functionality for importing modules in Grumpy.""" + + +from __future__ import unicode_literals + +import collections +import functools +import os +import os.path + +from grumpy.compiler import util +from grumpy import pythonparser +from grumpy.pythonparser import algorithm +from grumpy.pythonparser import ast + + +_NATIVE_MODULE_PREFIX = '__go__/' + + +class Import(object): + """Represents a single module import and all its associated bindings. + + Each import pertains to a single module that is imported. Thus one import + statement may produce multiple Import objects. E.g. "import foo, bar" makes + an Import object for module foo and another one for module bar. + """ + + Binding = collections.namedtuple('Binding', ('bind_type', 'alias', 'value')) + + MODULE = "" + MEMBER = "" + + def __init__(self, name, script=None, is_native=False): + self.name = name + self.script = script + self.is_native = is_native + self.bindings = [] + + def add_binding(self, bind_type, alias, value): + self.bindings.append(Import.Binding(bind_type, alias, value)) + + +class Importer(algorithm.Visitor): + """Visits import nodes and produces corresponding Import objects.""" + + # pylint: disable=invalid-name,missing-docstring,no-init + + def __init__(self, gopath, modname, script, absolute_import): + self.pathdirs = [] + if gopath: + self.pathdirs.extend(os.path.join(d, 'src', '__python__') + for d in gopath.split(os.pathsep)) + dirname, basename = os.path.split(script) + if basename == '__init__.py': + self.package_dir = dirname + self.package_name = modname + elif (modname.find('.') != -1 and + os.path.isfile(os.path.join(dirname, '__init__.py'))): + self.package_dir = dirname + self.package_name = modname[:modname.rfind('.')] + else: + self.package_dir = None + self.package_name = None + self.absolute_import = absolute_import + + def generic_visit(self, node): + raise ValueError('Import cannot visit {} node'.format(type(node).__name__)) + + def visit_Import(self, node): + imports = [] + for alias in node.names: + if alias.name.startswith(_NATIVE_MODULE_PREFIX): + imp = Import(alias.name, is_native=True) + asname = alias.asname if alias.asname else alias.name.split('/')[-1] + imp.add_binding(Import.MODULE, asname, 0) + else: + imp = self._resolve_import(node, alias.name) + if alias.asname: + imp.add_binding(Import.MODULE, alias.asname, imp.name.count('.')) + else: + parts = alias.name.split('.') + imp.add_binding(Import.MODULE, parts[0], + imp.name.count('.') - len(parts) + 1) + imports.append(imp) + return imports + + def visit_ImportFrom(self, node): + if any(a.name == '*' for a in node.names): + raise util.ImportError(node, 'wildcard member import is not implemented') + + if not node.level and node.module == '__future__': + return [] + + if not node.level and node.module.startswith(_NATIVE_MODULE_PREFIX): + imp = Import(node.module, is_native=True) + for alias in node.names: + asname = alias.asname or alias.name + imp.add_binding(Import.MEMBER, asname, alias.name) + return [imp] + + imports = [] + if not node.module: + # Import of the form 'from .. import foo, bar'. All named imports must be + # modules, not module members. 
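# (Worked example against the tree used in imputil_test.py: inside
# bar/fred/__init__.py, `from .. import baz` goes through
# _resolve_relative_import(level=2, ...) and yields the module 'bar.baz'; the
# MODULE binding then carries imp.name.count('.') == 1, i.e. how many levels
# deep in the dotted name the bound module sits.)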
+ for alias in node.names: + imp = self._resolve_relative_import(node.level, node, alias.name) + imp.add_binding(Import.MODULE, alias.asname or alias.name, + imp.name.count('.')) + imports.append(imp) + return imports + + member_imp = None + for alias in node.names: + asname = alias.asname or alias.name + if node.level: + resolver = functools.partial(self._resolve_relative_import, node.level) + else: + resolver = self._resolve_import + try: + imp = resolver(node, '{}.{}'.format(node.module, alias.name)) + except util.ImportError: + # A member (not a submodule) is being imported, so bind it. + if not member_imp: + member_imp = resolver(node, node.module) + imports.append(member_imp) + member_imp.add_binding(Import.MEMBER, asname, alias.name) + else: + # Imported name is a submodule within a package, so bind that module. + imp.add_binding(Import.MODULE, asname, imp.name.count('.')) + imports.append(imp) + return imports + + def _resolve_import(self, node, modname): + if not self.absolute_import and self.package_dir: + script = find_script(self.package_dir, modname) + if script: + return Import('{}.{}'.format(self.package_name, modname), script) + for dirname in self.pathdirs: + script = find_script(dirname, modname) + if script: + return Import(modname, script) + raise util.ImportError(node, 'no such module: {}'.format(modname)) + + def _resolve_relative_import(self, level, node, modname): + if not self.package_dir: + raise util.ImportError(node, 'attempted relative import in non-package') + uplevel = level - 1 + if uplevel > self.package_name.count('.'): + raise util.ImportError( + node, 'attempted relative import beyond toplevel package') + dirname = os.path.normpath(os.path.join( + self.package_dir, *(['..'] * uplevel))) + script = find_script(dirname, modname) + if not script: + raise util.ImportError(node, 'no such module: {}'.format(modname)) + parts = self.package_name.split('.') + return Import('.'.join(parts[:len(parts)-uplevel]) + '.' + modname, script) + + +class _ImportCollector(algorithm.Visitor): + + # pylint: disable=invalid-name + + def __init__(self, importer, future_node): + self.importer = importer + self.future_node = future_node + self.imports = [] + + def visit_Import(self, node): + self.imports.extend(self.importer.visit(node)) + + def visit_ImportFrom(self, node): + if node.module == '__future__': + if node != self.future_node: + raise util.LateFutureError(node) + return + self.imports.extend(self.importer.visit(node)) + + +def collect_imports(modname, script, gopath): + with open(script) as py_file: + py_contents = py_file.read() + mod = pythonparser.parse(py_contents) + future_node, future_features = parse_future_features(mod) + importer = Importer(gopath, modname, script, future_features.absolute_import) + collector = _ImportCollector(importer, future_node) + collector.visit(mod) + return collector.imports + + +def calculate_transitive_deps(modname, script, gopath): + """Determines all modules that script transitively depends upon.""" + deps = set() + def calc(modname, script): + if modname in deps: + return + deps.add(modname) + for imp in collect_imports(modname, script, gopath): + if imp.is_native: + deps.add(imp.name) + continue + parts = imp.name.split('.') + calc(imp.name, imp.script) + if len(parts) == 1: + continue + # For submodules, the parent packages are also deps. 
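# e.g. a dependency on bar.fred.quux also pulls in bar.fred and bar, each
# resolved through its package's __init__.py.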
+ package_dir, filename = os.path.split(imp.script) + if filename == '__init__.py': + package_dir = os.path.dirname(package_dir) + for i in xrange(len(parts) - 1, 0, -1): + modname = '.'.join(parts[:i]) + script = os.path.join(package_dir, '__init__.py') + calc(modname, script) + package_dir = os.path.dirname(package_dir) + calc(modname, script) + deps.remove(modname) + return deps + + +def find_script(dirname, name): + prefix = os.path.join(dirname, name.replace('.', os.sep)) + script = prefix + '.py' + if os.path.isfile(script): + return script + script = os.path.join(prefix, '__init__.py') + if os.path.isfile(script): + return script + return None + + +_FUTURE_FEATURES = ( + 'absolute_import', + 'division', + 'print_function', + 'unicode_literals', +) + +_IMPLEMENTED_FUTURE_FEATURES = ( + 'absolute_import', + 'print_function', + 'unicode_literals' +) + +# These future features are already in the language proper as of 2.6, so +# importing them via __future__ has no effect. +_REDUNDANT_FUTURE_FEATURES = ('generators', 'with_statement', 'nested_scopes') + + +class FutureFeatures(object): + """Spec for future feature flags imported by a module.""" + + def __init__(self, absolute_import=False, division=False, + print_function=False, unicode_literals=False): + self.absolute_import = absolute_import + self.division = division + self.print_function = print_function + self.unicode_literals = unicode_literals + + +def _make_future_features(node): + """Processes a future import statement, returning set of flags it defines.""" + assert isinstance(node, ast.ImportFrom) + assert node.module == '__future__' + features = FutureFeatures() + for alias in node.names: + name = alias.name + if name in _FUTURE_FEATURES: + if name not in _IMPLEMENTED_FUTURE_FEATURES: + msg = 'future feature {} not yet implemented by grumpy'.format(name) + raise util.ParseError(node, msg) + setattr(features, name, True) + elif name == 'braces': + raise util.ParseError(node, 'not a chance') + elif name not in _REDUNDANT_FUTURE_FEATURES: + msg = 'future feature {} is not defined'.format(name) + raise util.ParseError(node, msg) + return features + + +def parse_future_features(mod): + """Accumulates a set of flags for the compiler __future__ imports.""" + assert isinstance(mod, ast.Module) + found_docstring = False + for node in mod.body: + if isinstance(node, ast.ImportFrom): + if node.module == '__future__': + return node, _make_future_features(node) + break + elif isinstance(node, ast.Expr) and not found_docstring: + if not isinstance(node.value, ast.Str): + break + found_docstring = True + else: + break + return None, FutureFeatures() diff --git a/compiler/imputil_test.py b/compiler/imputil_test.py new file mode 100644 index 00000000..600afdf0 --- /dev/null +++ b/compiler/imputil_test.py @@ -0,0 +1,355 @@ +# coding=utf-8 + +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
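
A rough usage sketch for the two entry points defined above; the module name, script path, and GOPATH are hypothetical:

from grumpy.compiler import imputil

# Each Import records the module it refers to plus the names it binds.
for imp in imputil.collect_imports('mypkg.mymod', '/src/mypkg/mymod.py', '/gopath'):
    for binding in imp.bindings:
        print('%s binds %s -> %r' % (imp.name, binding.alias, binding.value))

# Names of every module (including parent packages) the script depends on.
deps = imputil.calculate_transitive_deps('mypkg.mymod', '/src/mypkg/mymod.py', '/gopath')
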
+ +"""Tests ImportVisitor and related classes.""" + +from __future__ import unicode_literals + +import copy +import os +import shutil +import tempfile +import textwrap +import unittest + +from grumpy.compiler import imputil +from grumpy.compiler import util +from grumpy import pythonparser + + +class ImportVisitorTest(unittest.TestCase): + + _PATH_SPEC = { + 'foo.py': None, + 'qux.py': None, + 'bar/': { + 'fred/': { + '__init__.py': None, + 'quux.py': None, + }, + '__init__.py': None, + 'baz.py': None, + 'foo.py': None, + }, + 'baz.py': None, + } + + def setUp(self): + self.rootdir = tempfile.mkdtemp() + self.pydir = os.path.join(self.rootdir, 'src', '__python__') + self._materialize_tree( + self.rootdir, {'src/': {'__python__/': self._PATH_SPEC}}) + foo_script = os.path.join(self.rootdir, 'foo.py') + self.importer = imputil.Importer(self.rootdir, 'foo', foo_script, False) + bar_script = os.path.join(self.pydir, 'bar', '__init__.py') + self.bar_importer = imputil.Importer( + self.rootdir, 'bar', bar_script, False) + fred_script = os.path.join(self.pydir, 'bar', 'fred', '__init__.py') + self.fred_importer = imputil.Importer( + self.rootdir, 'bar.fred', fred_script, False) + + self.foo_import = imputil.Import( + 'foo', os.path.join(self.pydir, 'foo.py')) + self.qux_import = imputil.Import( + 'qux', os.path.join(self.pydir, 'qux.py')) + self.bar_import = imputil.Import( + 'bar', os.path.join(self.pydir, 'bar/__init__.py')) + self.fred_import = imputil.Import( + 'bar.fred', os.path.join(self.pydir, 'bar/fred/__init__.py')) + self.quux_import = imputil.Import( + 'bar.fred.quux', os.path.join(self.pydir, 'bar/fred/quux.py')) + self.baz2_import = imputil.Import( + 'bar.baz', os.path.join(self.pydir, 'bar/baz.py')) + self.foo2_import = imputil.Import( + 'bar.foo', os.path.join(self.pydir, 'bar/foo.py')) + self.baz_import = imputil.Import( + 'baz', os.path.join(self.pydir, 'baz.py')) + + def tearDown(self): + shutil.rmtree(self.rootdir) + + def testImportEmptyPath(self): + importer = imputil.Importer(None, 'foo', 'foo.py', False) + self.assertRaises(util.ImportError, importer.visit, + pythonparser.parse('import bar').body[0]) + + def testImportTopLevelModule(self): + imp = copy.deepcopy(self.qux_import) + imp.add_binding(imputil.Import.MODULE, 'qux', 0) + self._check_imports('import qux', [imp]) + + def testImportTopLevelPackage(self): + imp = copy.deepcopy(self.bar_import) + imp.add_binding(imputil.Import.MODULE, 'bar', 0) + self._check_imports('import bar', [imp]) + + def testImportPackageModuleAbsolute(self): + imp = copy.deepcopy(self.baz2_import) + imp.add_binding(imputil.Import.MODULE, 'bar', 0) + self._check_imports('import bar.baz', [imp]) + + def testImportFromSubModule(self): + imp = copy.deepcopy(self.baz2_import) + imp.add_binding(imputil.Import.MODULE, 'baz', 1) + self._check_imports('from bar import baz', [imp]) + + def testImportPackageModuleRelative(self): + imp = copy.deepcopy(self.baz2_import) + imp.add_binding(imputil.Import.MODULE, 'baz', 1) + got = self.bar_importer.visit(pythonparser.parse('import baz').body[0]) + self._assert_imports_equal([imp], got) + + def testImportPackageModuleRelativeFromSubModule(self): + imp = copy.deepcopy(self.baz2_import) + imp.add_binding(imputil.Import.MODULE, 'baz', 1) + foo_script = os.path.join(self.pydir, 'bar', 'foo.py') + importer = imputil.Importer(self.rootdir, 'bar.foo', foo_script, False) + got = importer.visit(pythonparser.parse('import baz').body[0]) + self._assert_imports_equal([imp], got) + + def 
testImportPackageModuleAbsoluteImport(self): + imp = copy.deepcopy(self.baz_import) + imp.add_binding(imputil.Import.MODULE, 'baz', 0) + bar_script = os.path.join(self.pydir, 'bar', '__init__.py') + importer = imputil.Importer(self.rootdir, 'bar', bar_script, True) + got = importer.visit(pythonparser.parse('import baz').body[0]) + self._assert_imports_equal([imp], got) + + def testImportMultiple(self): + imp1 = copy.deepcopy(self.foo_import) + imp1.add_binding(imputil.Import.MODULE, 'foo', 0) + imp2 = copy.deepcopy(self.baz2_import) + imp2.add_binding(imputil.Import.MODULE, 'bar', 0) + self._check_imports('import foo, bar.baz', [imp1, imp2]) + + def testImportAs(self): + imp = copy.deepcopy(self.foo_import) + imp.add_binding(imputil.Import.MODULE, 'bar', 0) + self._check_imports('import foo as bar', [imp]) + + def testImportFrom(self): + imp = copy.deepcopy(self.baz2_import) + imp.add_binding(imputil.Import.MODULE, 'baz', 1) + self._check_imports('from bar import baz', [imp]) + + def testImportFromMember(self): + imp = copy.deepcopy(self.foo_import) + imp.add_binding(imputil.Import.MEMBER, 'bar', 'bar') + self._check_imports('from foo import bar', [imp]) + + def testImportFromMultiple(self): + imp1 = copy.deepcopy(self.baz2_import) + imp1.add_binding(imputil.Import.MODULE, 'baz', 1) + imp2 = copy.deepcopy(self.foo2_import) + imp2.add_binding(imputil.Import.MODULE, 'foo', 1) + self._check_imports('from bar import baz, foo', [imp1, imp2]) + + def testImportFromMixedMembers(self): + imp1 = copy.deepcopy(self.bar_import) + imp1.add_binding(imputil.Import.MEMBER, 'qux', 'qux') + imp2 = copy.deepcopy(self.baz2_import) + imp2.add_binding(imputil.Import.MODULE, 'baz', 1) + self._check_imports('from bar import qux, baz', [imp1, imp2]) + + def testImportFromAs(self): + imp = copy.deepcopy(self.baz2_import) + imp.add_binding(imputil.Import.MODULE, 'qux', 1) + self._check_imports('from bar import baz as qux', [imp]) + + def testImportFromAsMembers(self): + imp = copy.deepcopy(self.foo_import) + imp.add_binding(imputil.Import.MEMBER, 'baz', 'bar') + self._check_imports('from foo import bar as baz', [imp]) + + def testImportFromWildcardRaises(self): + self.assertRaises(util.ImportError, self.importer.visit, + pythonparser.parse('from foo import *').body[0]) + + def testImportFromFuture(self): + self._check_imports('from __future__ import print_function', []) + + def testImportFromNative(self): + imp = imputil.Import('__go__/fmt', is_native=True) + imp.add_binding(imputil.Import.MEMBER, 'Printf', 'Printf') + self._check_imports('from "__go__/fmt" import Printf', [imp]) + + def testImportFromNativeMultiple(self): + imp = imputil.Import('__go__/fmt', is_native=True) + imp.add_binding(imputil.Import.MEMBER, 'Printf', 'Printf') + imp.add_binding(imputil.Import.MEMBER, 'Println', 'Println') + self._check_imports('from "__go__/fmt" import Printf, Println', [imp]) + + def testImportFromNativeAs(self): + imp = imputil.Import('__go__/fmt', is_native=True) + imp.add_binding(imputil.Import.MEMBER, 'foo', 'Printf') + self._check_imports('from "__go__/fmt" import Printf as foo', [imp]) + + def testRelativeImportNonPackage(self): + self.assertRaises(util.ImportError, self.importer.visit, + pythonparser.parse('from . import bar').body[0]) + + def testRelativeImportBeyondTopLevel(self): + self.assertRaises(util.ImportError, self.bar_importer.visit, + pythonparser.parse('from .. 
import qux').body[0]) + + def testRelativeModuleNoExist(self): + self.assertRaises(util.ImportError, self.bar_importer.visit, + pythonparser.parse('from . import qux').body[0]) + + def testRelativeModule(self): + imp = copy.deepcopy(self.foo2_import) + imp.add_binding(imputil.Import.MODULE, 'foo', 1) + node = pythonparser.parse('from . import foo').body[0] + self._assert_imports_equal([imp], self.bar_importer.visit(node)) + + def testRelativeModuleFromSubModule(self): + imp = copy.deepcopy(self.foo2_import) + imp.add_binding(imputil.Import.MODULE, 'foo', 1) + baz_script = os.path.join(self.pydir, 'bar', 'baz.py') + importer = imputil.Importer(self.rootdir, 'bar.baz', baz_script, False) + node = pythonparser.parse('from . import foo').body[0] + self._assert_imports_equal([imp], importer.visit(node)) + + def testRelativeModuleMember(self): + imp = copy.deepcopy(self.foo2_import) + imp.add_binding(imputil.Import.MEMBER, 'qux', 'qux') + node = pythonparser.parse('from .foo import qux').body[0] + self._assert_imports_equal([imp], self.bar_importer.visit(node)) + + def testRelativeModuleMemberMixed(self): + imp1 = copy.deepcopy(self.fred_import) + imp1.add_binding(imputil.Import.MEMBER, 'qux', 'qux') + imp2 = copy.deepcopy(self.quux_import) + imp2.add_binding(imputil.Import.MODULE, 'quux', 2) + node = pythonparser.parse('from .fred import qux, quux').body[0] + self._assert_imports_equal([imp1, imp2], self.bar_importer.visit(node)) + + def testRelativeUpLevel(self): + imp = copy.deepcopy(self.foo2_import) + imp.add_binding(imputil.Import.MODULE, 'foo', 1) + node = pythonparser.parse('from .. import foo').body[0] + self._assert_imports_equal([imp], self.fred_importer.visit(node)) + + def testRelativeUpLevelMember(self): + imp = copy.deepcopy(self.foo2_import) + imp.add_binding(imputil.Import.MEMBER, 'qux', 'qux') + node = pythonparser.parse('from ..foo import qux').body[0] + self._assert_imports_equal([imp], self.fred_importer.visit(node)) + + def _check_imports(self, stmt, want): + got = self.importer.visit(pythonparser.parse(stmt).body[0]) + self._assert_imports_equal(want, got) + + def _assert_imports_equal(self, want, got): + self.assertEqual([imp.__dict__ for imp in want], + [imp.__dict__ for imp in got]) + + def _materialize_tree(self, dirname, spec): + for name, sub_spec in spec.iteritems(): + if name.endswith('/'): + subdir = os.path.join(dirname, name[:-1]) + os.mkdir(subdir) + self._materialize_tree(subdir, sub_spec) + else: + with open(os.path.join(dirname, name), 'w'): + pass + + +class MakeFutureFeaturesTest(unittest.TestCase): + + def testImportFromFuture(self): + testcases = [ + ('from __future__ import print_function', + imputil.FutureFeatures(print_function=True)), + ('from __future__ import generators', imputil.FutureFeatures()), + ('from __future__ import generators, print_function', + imputil.FutureFeatures(print_function=True)), + ] + + for tc in testcases: + source, want = tc + mod = pythonparser.parse(textwrap.dedent(source)) + node = mod.body[0] + got = imputil._make_future_features(node) # pylint: disable=protected-access + self.assertEqual(want.__dict__, got.__dict__) + + def testImportFromFutureParseError(self): + testcases = [ + # NOTE: move this group to testImportFromFuture as they are implemented + # by grumpy + ('from __future__ import division', + r'future feature \w+ not yet implemented'), + ('from __future__ import braces', 'not a chance'), + ('from __future__ import nonexistant_feature', + r'future feature \w+ is not defined'), + ] + + for tc in testcases: + 
source, want_regexp = tc + mod = pythonparser.parse(source) + node = mod.body[0] + self.assertRaisesRegexp(util.ParseError, want_regexp, + imputil._make_future_features, node) # pylint: disable=protected-access + + +class ParseFutureFeaturesTest(unittest.TestCase): + + def testFutureFeatures(self): + testcases = [ + ('from __future__ import print_function', + imputil.FutureFeatures(print_function=True)), + ("""\ + "module docstring" + + from __future__ import print_function + """, imputil.FutureFeatures(print_function=True)), + ("""\ + "module docstring" + + from __future__ import print_function, with_statement + from __future__ import nested_scopes + """, imputil.FutureFeatures(print_function=True)), + ('from __future__ import absolute_import', + imputil.FutureFeatures(absolute_import=True)), + ('from __future__ import absolute_import, print_function', + imputil.FutureFeatures(absolute_import=True, print_function=True)), + ('foo = 123\nfrom __future__ import print_function', + imputil.FutureFeatures()), + ('import os\nfrom __future__ import print_function', + imputil.FutureFeatures()), + ] + + for tc in testcases: + source, want = tc + mod = pythonparser.parse(textwrap.dedent(source)) + _, got = imputil.parse_future_features(mod) + self.assertEqual(want.__dict__, got.__dict__) + + def testUnimplementedFutureRaises(self): + mod = pythonparser.parse('from __future__ import division') + msg = 'future feature division not yet implemented by grumpy' + self.assertRaisesRegexp(util.ParseError, msg, + imputil.parse_future_features, mod) + + def testUndefinedFutureRaises(self): + mod = pythonparser.parse('from __future__ import foo') + self.assertRaisesRegexp( + util.ParseError, 'future feature foo is not defined', + imputil.parse_future_features, mod) + + +if __name__ == '__main__': + unittest.main() diff --git a/compiler/shard_test.py b/compiler/shard_test.py index 597815be..97d2943d 100644 --- a/compiler/shard_test.py +++ b/compiler/shard_test.py @@ -14,6 +14,8 @@ """Wrapper for unit tests that loads a subset of all test methods.""" +from __future__ import unicode_literals + import argparse import random import re diff --git a/compiler/stmt.py b/compiler/stmt.py index 7bcda879..8daac1db 100644 --- a/compiler/stmt.py +++ b/compiler/stmt.py @@ -16,17 +16,20 @@ """Visitor class for traversing Python statements.""" -import ast +from __future__ import unicode_literals + import string import textwrap from grumpy.compiler import block from grumpy.compiler import expr from grumpy.compiler import expr_visitor +from grumpy.compiler import imputil from grumpy.compiler import util +from grumpy.pythonparser import algorithm +from grumpy.pythonparser import ast -_NATIVE_MODULE_PREFIX = '__go__.' _NATIVE_TYPE_PREFIX = 'type_' # Partial list of known vcs for go module import @@ -40,115 +43,29 @@ _nil_expr = expr.nil_expr -# Parser flags, set on 'from __future__ import *', see parser_flags on -# StatementVisitor below. Note these have the same values as CPython. -FUTURE_DIVISION = 0x2000 -FUTURE_ABSOLUTE_IMPORT = 0x4000 -FUTURE_PRINT_FUNCTION = 0x10000 -FUTURE_UNICODE_LITERALS = 0x20000 - -# Names for future features in 'from __future__ import *'. Map from name in the -# import statement to a tuple of the flag for parser, and whether we've (grumpy) -# implemented the feature yet. 
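
For comparison with the flag constants removed here, a short sketch of how the same information is now obtained through imputil (the source text is made up):

from grumpy import pythonparser
from grumpy.compiler import imputil

mod = pythonparser.parse('"""doc"""\nfrom __future__ import print_function\n')
node, features = imputil.parse_future_features(mod)
assert features.print_function and not features.unicode_literals
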
-future_features = { - "division": (FUTURE_DIVISION, False), - "absolute_import": (FUTURE_ABSOLUTE_IMPORT, False), - "print_function": (FUTURE_PRINT_FUNCTION, True), - "unicode_literals": (FUTURE_UNICODE_LITERALS, False), -} - -# These future features are already in the language proper as of 2.6, so -# importing them via __future__ has no effect. -redundant_future_features = ["generators", "with_statement", "nested_scopes"] - -late_future = 'from __future__ imports must occur at the beginning of the file' - - -def import_from_future(node): - """Processes a future import statement, returning set of flags it defines.""" - assert isinstance(node, ast.ImportFrom) - assert node.module == '__future__' - flags = 0 - for alias in node.names: - name = alias.name - if name in future_features: - flag, implemented = future_features[name] - if not implemented: - msg = 'future feature {} not yet implemented by grumpy'.format(name) - raise util.ParseError(node, msg) - flags |= flag - elif name == 'braces': - raise util.ParseError(node, 'not a chance') - elif name not in redundant_future_features: - msg = 'future feature {} is not defined'.format(name) - raise util.ParseError(node, msg) - return flags - - -class FutureFeatures(object): - def __init__(self): - self.parser_flags = 0 - self.future_lineno = 0 - - -def visit_future(node): - """Accumulates a set of compiler flags for the compiler __future__ imports. - - Returns an instance of FutureFeatures which encapsulates the flags and the - line number of the last valid future import parsed. A downstream parser can - use the latter to detect invalid future imports that appear too late in the - file. - """ - # If this is the module node, do an initial pass through the module body's - # statements to detect future imports and process their directives (i.e., - # set compiler flags), and detect ones that don't appear at the beginning of - # the file. The only things that can proceed a future statement are other - # future statements and/or a doc string. - assert isinstance(node, ast.Module) - ff = FutureFeatures() - done = False - found_docstring = False - for node in node.body: - if isinstance(node, ast.ImportFrom): - modname = node.module - if modname == '__future__': - if done: - raise util.ParseError(node, late_future) - ff.parser_flags |= import_from_future(node) - ff.future_lineno = node.lineno - else: - done = True - elif isinstance(node, ast.Expr) and not found_docstring: - e = node.value - if not isinstance(e, ast.Str): # pylint: disable=simplifiable-if-statement - done = True - else: - found_docstring = True - else: - done = True - return ff - - -class StatementVisitor(ast.NodeVisitor): +class StatementVisitor(algorithm.Visitor): """Outputs Go statements to a Writer for the given Python nodes.""" # pylint: disable=invalid-name,missing-docstring - def __init__(self, block_): + def __init__(self, block_, future_node=None): self.block = block_ - self.future_features = self.block.future_features or FutureFeatures() + self.future_node = future_node self.writer = util.Writer() - self.expr_visitor = expr_visitor.ExprVisitor(self.block, self.writer) + self.expr_visitor = expr_visitor.ExprVisitor(self) def generic_visit(self, node): msg = 'node not yet implemented: {}'.format(type(node).__name__) raise util.ParseError(node, msg) + def visit_expr(self, node): + return self.expr_visitor.visit(node) + def visit_Assert(self, node): self._write_py_context(node.lineno) # TODO: Only evaluate msg if cond is false. 
- with self.expr_visitor.visit(node.msg) if node.msg else _nil_expr as msg,\ - self.expr_visitor.visit(node.test) as cond: + with self.visit_expr(node.msg) if node.msg else _nil_expr as msg,\ + self.visit_expr(node.test) as cond: self.writer.write_checked_call1( 'πg.Assert(πF, {}, {})', cond.expr, msg.expr) @@ -158,8 +75,8 @@ def visit_AugAssign(self, node): fmt = 'augmented assignment op not implemented: {}' raise util.ParseError(node, fmt.format(op_type.__name__)) self._write_py_context(node.lineno) - with self.expr_visitor.visit(node.target) as target,\ - self.expr_visitor.visit(node.value) as value,\ + with self.visit_expr(node.target) as target,\ + self.visit_expr(node.value) as value,\ self.block.alloc_temp() as temp: self.writer.write_checked_call2( temp, StatementVisitor._AUG_ASSIGN_TEMPLATES[op_type], @@ -168,7 +85,7 @@ def visit_AugAssign(self, node): def visit_Assign(self, node): self._write_py_context(node.lineno) - with self.expr_visitor.visit(node.value) as value: + with self.visit_expr(node.value) as value: for target in node.targets: self._tie_target(target, value.expr) @@ -176,7 +93,9 @@ def visit_Break(self, node): if not self.block.loop_stack: raise util.ParseError(node, "'break' not in loop") self._write_py_context(node.lineno) - self.writer.write('goto Label{}'.format(self.block.top_loop().end_label)) + self.writer.write_tmpl(textwrap.dedent("""\ + $breakvar = true + continue"""), breakvar=self.block.top_loop().breakvar.name) def visit_ClassDef(self, node): # Since we only care about global vars, we end up throwing away the locals @@ -189,7 +108,7 @@ def visit_ClassDef(self, node): if v.type == block.Var.TYPE_GLOBAL} # Visit all the statements inside body of the class definition. body_visitor = StatementVisitor(block.ClassBlock( - self.block, node.name, global_vars)) + self.block, node.name, global_vars), self.future_node) # Indent so that the function body is aligned with the goto labels. 
with body_visitor.writer.indent_block(): body_visitor._visit_each(node.body) # pylint: disable=protected-access @@ -202,39 +121,41 @@ def visit_ClassDef(self, node): self.writer.write('{} = make([]*πg.Object, {})'.format( bases.expr, len(node.bases))) for i, b in enumerate(node.bases): - with self.expr_visitor.visit(b) as b: + with self.visit_expr(b) as b: self.writer.write('{}[{}] = {}'.format(bases.expr, i, b.expr)) self.writer.write('{} = πg.NewDict()'.format(cls.name)) self.writer.write_checked_call2( mod_name, 'πF.Globals().GetItem(πF, {}.ToObject())', - self.block.intern('__name__')) + self.block.root.intern('__name__')) self.writer.write_checked_call1( '{}.SetItem(πF, {}.ToObject(), {})', - cls.expr, self.block.intern('__module__'), mod_name.expr) + cls.expr, self.block.root.intern('__module__'), mod_name.expr) tmpl = textwrap.dedent(""" _, πE = πg.NewCode($name, $filename, nil, 0, func(πF *πg.Frame, _ []*πg.Object) (*πg.Object, *πg.BaseException) { \tπClass := $cls \t_ = πClass""") self.writer.write_tmpl(tmpl, name=util.go_str(node.name), - filename=util.go_str(self.block.filename), + filename=util.go_str(self.block.root.filename), cls=cls.expr) with self.writer.indent_block(): self.writer.write_temp_decls(body_visitor.block) self.writer.write_block(body_visitor.block, - body_visitor.writer.out.getvalue()) + body_visitor.writer.getvalue()) + self.writer.write('return nil, nil') tmpl = textwrap.dedent("""\ }).Eval(πF, πF.Globals(), nil, nil) if πE != nil { - \treturn nil, πE + \tcontinue } if $meta, πE = $cls.GetItem(πF, $metaclass_str.ToObject()); πE != nil { - \treturn nil, πE + \tcontinue } if $meta == nil { \t$meta = πg.TypeType.ToObject() }""") - self.writer.write_tmpl(tmpl, meta=meta.name, cls=cls.expr, - metaclass_str=self.block.intern('__metaclass__')) + self.writer.write_tmpl( + tmpl, meta=meta.name, cls=cls.expr, + metaclass_str=self.block.root.intern('__metaclass__')) with self.block.alloc_temp() as type_: type_expr = ('{}.Call(πF, []*πg.Object{{πg.NewStr({}).ToObject(), ' 'πg.NewTuple({}...).ToObject(), {}.ToObject()}}, nil)') @@ -247,21 +168,21 @@ def visit_Continue(self, node): if not self.block.loop_stack: raise util.ParseError(node, "'continue' not in loop") self._write_py_context(node.lineno) - self.writer.write('goto Label{}'.format(self.block.top_loop().start_label)) + self.writer.write('continue') def visit_Delete(self, node): self._write_py_context(node.lineno) for target in node.targets: if isinstance(target, ast.Attribute): - with self.expr_visitor.visit(target.value) as t: + with self.visit_expr(target.value) as t: self.writer.write_checked_call1( - 'πg.DelAttr(πF, {}, {})', t.expr, self.block.intern(target.attr)) + 'πg.DelAttr(πF, {}, {})', t.expr, + self.block.root.intern(target.attr)) elif isinstance(target, ast.Name): self.block.del_var(self.writer, target.id) elif isinstance(target, ast.Subscript): - assert isinstance(target.ctx, ast.Del) - with self.expr_visitor.visit(target.value) as t,\ - self.expr_visitor.visit(target.slice) as index: + with self.visit_expr(target.value) as t,\ + self.visit_expr(target.slice) as index: self.writer.write_checked_call1('πg.DelItem(πF, {}, {})', t.expr, index.expr) else: @@ -270,54 +191,41 @@ def visit_Delete(self, node): def visit_Expr(self, node): self._write_py_context(node.lineno) - self.expr_visitor.visit(node.value).free() + self.visit_expr(node.value).free() def visit_For(self, node): - loop = self.block.push_loop() - orelse_label = self.block.genlabel() if node.orelse else loop.end_label - 
self._write_py_context(node.lineno) - with self.expr_visitor.visit(node.iter) as iter_expr, \ - self.block.alloc_temp() as i, \ - self.block.alloc_temp() as n: - self.writer.write_checked_call2(i, 'πg.Iter(πF, {})', iter_expr.expr) - self.writer.write_label(loop.start_label) - tmpl = textwrap.dedent("""\ - if $n, πE = πg.Next(πF, $i); πE != nil { - \tisStop, exc := πg.IsInstance(πF, πE.ToObject(), πg.StopIterationType.ToObject()) - \tif exc != nil { - \t\tπE = exc - \t\tcontinue - \t} - \tif !isStop { - \t\tcontinue - \t} - \tπE = nil - \tπF.RestoreExc(nil, nil) - \tgoto Label$orelse - }""") - self.writer.write_tmpl(tmpl, n=n.name, i=i.expr, orelse=orelse_label) - self._tie_target(node.target, n.expr) - self._visit_each(node.body) - self.writer.write('goto Label{}'.format(loop.start_label)) - - self.block.pop_loop() - if node.orelse: - self.writer.write_label(orelse_label) - self._visit_each(node.orelse) - # Avoid label "defined and not used" in case there's no break statements. - self.writer.write('goto Label{}'.format(loop.end_label)) - self.writer.write_label(loop.end_label) + with self.block.alloc_temp() as i: + with self.visit_expr(node.iter) as iter_expr: + self.writer.write_checked_call2(i, 'πg.Iter(πF, {})', iter_expr.expr) + def testfunc(testvar): + with self.block.alloc_temp() as n: + self.writer.write_tmpl(textwrap.dedent("""\ + if $n, πE = πg.Next(πF, $i); πE != nil { + \tisStop, exc := πg.IsInstance(πF, πE.ToObject(), πg.StopIterationType.ToObject()) + \tif exc != nil { + \t\tπE = exc + \t} else if isStop { + \t\tπE = nil + \t\tπF.RestoreExc(nil, nil) + \t} + \t$testvar = !isStop + } else { + \t$testvar = true"""), n=n.name, i=i.expr, testvar=testvar.name) + with self.writer.indent_block(): + self._tie_target(node.target, n.expr) + self.writer.write('}') + self._visit_loop(testfunc, node) def visit_FunctionDef(self, node): self._write_py_context(node.lineno + len(node.decorator_list)) - func = self.expr_visitor.visit_function_inline(node) + func = self.visit_function_inline(node) self.block.bind_var(self.writer, node.name, func.expr) while node.decorator_list: decorator = node.decorator_list.pop() - wrapped = ast.Name(node.name, ast.Load) - decorated = ast.Call(decorator, [wrapped], [], None, None) - target = ast.Assign([wrapped], decorated) - target.lineno = node.lineno + len(node.decorator_list) + wrapped = ast.Name(id=node.name) + decorated = ast.Call(func=decorator, args=[wrapped], keywords=[], + starargs=None, kwargs=None) + target = ast.Assign(targets=[wrapped], value=decorated, loc=node.loc) self.visit_Assign(target) def visit_Global(self, node): @@ -334,7 +242,7 @@ def visit_If(self, node): orelse = [node] while len(orelse) == 1 and isinstance(orelse[0], ast.If): ifnode = orelse[0] - with self.expr_visitor.visit(ifnode.test) as cond: + with self.visit_expr(ifnode.test) as cond: label = self.block.genlabel() # We goto the body of the if statement instead of executing it inline # because the body itself may be a goto target and Go does not support @@ -342,7 +250,7 @@ def visit_If(self, node): with self.block.alloc_temp('bool') as is_true: self.writer.write_tmpl(textwrap.dedent("""\ if $is_true, πE = πg.IsTrue(πF, $cond); πE != nil { - \treturn nil, πE + \tcontinue } if $is_true { \tgoto Label$label @@ -366,58 +274,17 @@ def visit_If(self, node): def visit_Import(self, node): self._write_py_context(node.lineno) - for alias in node.names: - if alias.name.startswith(_NATIVE_MODULE_PREFIX): - raise util.ParseError( - node, 'for native imports use "from __go__.xyz import ..." 
syntax') - with self._import(alias.name, 0) as mod: - asname = alias.asname or alias.name.split('.')[0] - self.block.bind_var(self.writer, asname, mod.expr) + for imp in self.block.root.importer.visit(node): + self._import_and_bind(imp) def visit_ImportFrom(self, node): - # Wildcard imports are not yet supported. - for alias in node.names: - if alias.name == '*': - msg = 'wildcard member import is not implemented: from %s import %s' % ( - node.module, alias.name) - raise util.ParseError(node, msg) self._write_py_context(node.lineno) - if node.module.startswith(_NATIVE_MODULE_PREFIX): - values = [alias.name for alias in node.names] - with self._import_native(node.module, values) as mod: - for alias in node.names: - # Strip the 'type_' prefix when populating the module. This means - # that, e.g. 'from __go__.foo import type_Bar' will populate foo with - # a member called Bar, not type_Bar (although the symbol in the - # importing module will still be type_Bar unless aliased). This bends - # the semantics of import but makes native module contents more - # sensible. - name = alias.name - if name.startswith(_NATIVE_TYPE_PREFIX): - name = name[len(_NATIVE_TYPE_PREFIX):] - with self.block.alloc_temp() as member: - self.writer.write_checked_call2( - member, 'πg.GetAttr(πF, {}, {}, nil)', - mod.expr, self.block.intern(name)) - self.block.bind_var( - self.writer, alias.asname or alias.name, member.expr) - elif node.module == '__future__': - # At this stage all future imports are done in an initial pass (see - # visit() above), so if they are encountered here after the last valid - # __future__ then it's a syntax error. - if node.lineno > self.future_features.future_lineno: - raise util.ParseError(node, late_future) - else: - # NOTE: Assume that the names being imported are all modules within a - # package. E.g. "from a.b import c" is importing the module c from package - # a.b, not some member of module b. We cannot distinguish between these - # two cases at compile time and the Google style guide forbids the latter - # so we support that use case only. 
- for alias in node.names: - name = '{}.{}'.format(node.module, alias.name) - with self._import(name, name.count('.')) as mod: - asname = alias.asname or alias.name - self.block.bind_var(self.writer, asname, mod.expr) + + if node.module == '__future__' and node != self.future_node: + raise util.LateFutureError(node) + + for imp in self.block.root.importer.visit(node): + self._import_and_bind(imp) def visit_Module(self, node): self._visit_each(node.body) @@ -426,24 +293,24 @@ def visit_Pass(self, node): self._write_py_context(node.lineno) def visit_Print(self, node): - if self.future_features.parser_flags & FUTURE_PRINT_FUNCTION: + if self.block.root.future_features.print_function: raise util.ParseError(node, 'syntax error (print is not a keyword)') self._write_py_context(node.lineno) with self.block.alloc_temp('[]*πg.Object') as args: self.writer.write('{} = make([]*πg.Object, {})'.format( args.expr, len(node.values))) for i, v in enumerate(node.values): - with self.expr_visitor.visit(v) as arg: + with self.visit_expr(v) as arg: self.writer.write('{}[{}] = {}'.format(args.expr, i, arg.expr)) self.writer.write_checked_call1('πg.Print(πF, {}, {})', args.expr, 'true' if node.nl else 'false') def visit_Raise(self, node): - with self.expr_visitor.visit(node.type) if node.type else _nil_expr as t,\ - self.expr_visitor.visit(node.inst) if node.inst else _nil_expr as inst,\ - self.expr_visitor.visit(node.tback) if node.tback else _nil_expr as tb: + with self.visit_expr(node.exc) if node.exc else _nil_expr as t,\ + self.visit_expr(node.inst) if node.inst else _nil_expr as inst,\ + self.visit_expr(node.tback) if node.tback else _nil_expr as tb: if node.inst: - assert node.type, 'raise had inst but no type' + assert node.exc, 'raise had inst but no type' if node.tback: assert node.inst, 'raise had tback but no inst' self._write_py_context(node.lineno) @@ -457,132 +324,105 @@ def visit_Return(self, node): if self.block.is_generator and node.value: raise util.ParseError(node, 'returning a value in a generator function') if node.value: - with self.expr_visitor.visit(node.value) as value: - self.writer.write('return {}, nil'.format(value.expr)) + with self.visit_expr(node.value) as value: + self.writer.write('πR = {}'.format(value.expr)) else: - self.writer.write('return nil, nil') + self.writer.write('πR = πg.None') + self.writer.write('continue') - def visit_TryExcept(self, node): # pylint: disable=g-doc-args + def visit_Try(self, node): # The general structure generated by this method is shown below: # # checkpoints.Push(Except) # # Checkpoints.Pop() # - # goto Done + # goto Finally # Except: # # Handler1: # - # goto Done + # Checkpoints.Pop() // Finally + # goto Finally # Handler2: # - # goto Done + # Checkpoints.Pop() // Finally + # goto Finally # ... - # Done: + # Finally: + # # # The dispatch table maps the current exception to the appropriate handler # label according to the exception clauses. # Write the try body. 
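
To see the checkpoint structure sketched above in actual output, something along the lines of the test helper later in this change prints the Go generated for a small try/finally:

from grumpy import pythonparser
from grumpy.compiler import block
from grumpy.compiler import imputil
from grumpy.compiler import stmt

source = 'try:\n    pass\nfinally:\n    pass\n'
mod = pythonparser.parse(source)
_, features = imputil.parse_future_features(mod)
importer = imputil.Importer(None, 'foo', 'foo.py', False)
mod_block = block.ModuleBlock(importer, '__main__', '', source, features)
visitor = stmt.StatementVisitor(mod_block)
visitor.visit(mod)
print(visitor.writer.getvalue())
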
self._write_py_context(node.lineno) - except_label = self.block.genlabel(is_checkpoint=True) - done_label = self.block.genlabel() - self.writer.write('πF.PushCheckpoint({})'.format(except_label)) + finally_label = self.block.genlabel(is_checkpoint=bool(node.finalbody)) + if node.finalbody: + self.writer.write('πF.PushCheckpoint({})'.format(finally_label)) + except_label = None + if node.handlers: + except_label = self.block.genlabel(is_checkpoint=True) + self.writer.write('πF.PushCheckpoint({})'.format(except_label)) self._visit_each(node.body) - self.writer.write('πF.PopCheckpoint()') + if except_label: + self.writer.write('πF.PopCheckpoint()') # except_label if node.orelse: self._visit_each(node.orelse) - self.writer.write('goto Label{}'.format(done_label)) + if node.finalbody: + self.writer.write('πF.PopCheckpoint()') # finally_label + self.writer.write('goto Label{}'.format(finally_label)) with self.block.alloc_temp('*πg.BaseException') as exc: - if (len(node.handlers) == 1 and not node.handlers[0].type and - not node.orelse): - # When there's just a bare except, no dispatch is required. - self._write_except_block(except_label, exc.expr, node.handlers[0]) - self.writer.write_label(done_label) - return - - with self.block.alloc_temp('*πg.Traceback') as tb: - self.writer.write_label(except_label) - self.writer.write('{}, {} = πF.ExcInfo()'.format(exc.expr, tb.expr)) - handler_labels = self._write_except_dispatcher( - exc.expr, tb.expr, node.handlers) - - # Write the bodies of each of the except handlers. - for handler_label, except_node in zip(handler_labels, node.handlers): - self._write_except_block(handler_label, exc.expr, except_node) - self.writer.write('goto Label{}'.format(done_label)) - - self.writer.write_label(done_label) - - def visit_TryFinally(self, node): # pylint: disable=g-doc-args - # The general structure generated by this method is shown below: - # - # Checkpoints.Push(Finally) - # - # Checkpoints.Pop() - # Finally: - # - - # Write the try body. - self._write_py_context(node.lineno) - finally_label = self.block.genlabel(is_checkpoint=True) - self.writer.write('πF.PushCheckpoint({})'.format(finally_label)) - self._visit_each(node.body) - self.writer.write('πF.PopCheckpoint()') - - # Write the finally body. - with self.block.alloc_temp('*πg.BaseException') as exc,\ - self.block.alloc_temp('*πg.Traceback') as tb: + if except_label: + with self.block.alloc_temp('*πg.Traceback') as tb: + self.writer.write_label(except_label) + self.writer.write_tmpl(textwrap.dedent("""\ + if πE == nil { + continue + } + πE = nil + $exc, $tb = πF.ExcInfo()"""), exc=exc.expr, tb=tb.expr) + handler_labels = self._write_except_dispatcher( + exc.expr, tb.expr, node.handlers) + + # Write the bodies of each of the except handlers. + for handler_label, except_node in zip(handler_labels, node.handlers): + self._write_except_block(handler_label, exc.expr, except_node) + if node.finalbody: + self.writer.write('πF.PopCheckpoint()') # finally_label + self.writer.write('goto Label{}'.format(finally_label)) + + # Write the finally body. 
self.writer.write_label(finally_label) - self.writer.write('πE = nil') - self.writer.write('{}, {} = πF.RestoreExc(nil, nil)'.format( - exc.expr, tb.expr)) - self._visit_each(node.finalbody) - self.writer.write_tmpl(textwrap.dedent("""\ - if $exc != nil { - \tπE = πF.Raise($exc.ToObject(), nil, $tb.ToObject()) - \tcontinue - }"""), exc=exc.expr, tb=tb.expr) + if node.finalbody: + with self.block.alloc_temp('*πg.Traceback') as tb: + self.writer.write('{}, {} = πF.RestoreExc(nil, nil)'.format( + exc.expr, tb.expr)) + self._visit_each(node.finalbody) + self.writer.write_tmpl(textwrap.dedent("""\ + if $exc != nil { + \tπE = πF.Raise($exc.ToObject(), nil, $tb.ToObject()) + \tcontinue + } + if πR != nil { + \tcontinue + }"""), exc=exc.expr, tb=tb.expr) def visit_While(self, node): - loop = self.block.push_loop() self._write_py_context(node.lineno) - self.writer.write_label(loop.start_label) - orelse_label = self.block.genlabel() if node.orelse else loop.end_label - with self.expr_visitor.visit(node.test) as cond,\ - self.block.alloc_temp('bool') as is_true: - self.writer.write_checked_call2(is_true, 'πg.IsTrue(πF, {})', cond.expr) - self.writer.write_tmpl(textwrap.dedent("""\ - if !$is_true { - \tgoto Label$orelse_label - }"""), is_true=is_true.expr, orelse_label=orelse_label) - self._visit_each(node.body) - self.writer.write('goto Label{}'.format(loop.start_label)) - if node.orelse: - self.writer.write_label(orelse_label) - self._visit_each(node.orelse) - # Avoid label "defined and not used" in case there's no break statements. - self.writer.write('goto Label{}'.format(loop.end_label)) - self.writer.write_label(loop.end_label) - self.block.pop_loop() - - _AUG_ASSIGN_TEMPLATES = { - ast.Add: 'πg.IAdd(πF, {lhs}, {rhs})', - ast.BitAnd: 'πg.IAnd(πF, {lhs}, {rhs})', - ast.Div: 'πg.IDiv(πF, {lhs}, {rhs})', - ast.Mod: 'πg.IMod(πF, {lhs}, {rhs})', - ast.Mult: 'πg.IMul(πF, {lhs}, {rhs})', - ast.BitOr: 'πg.IOr(πF, {lhs}, {rhs})', - ast.Sub: 'πg.ISub(πF, {lhs}, {rhs})', - ast.BitXor: 'πg.IXor(πF, {lhs}, {rhs})', - } + def testfunc(testvar): + with self.visit_expr(node.test) as cond: + self.writer.write_checked_call2( + testvar, 'πg.IsTrue(πF, {})', cond.expr) + self._visit_loop(testfunc, node) def visit_With(self, node): - self._write_py_context(node.lineno) + assert len(node.items) == 1, 'multiple items in a with not yet supported' + item = node.items[0] + self._write_py_context(node.loc.line()) # mgr := EXPR - with self.expr_visitor.visit(node.context_expr) as mgr,\ + with self.visit_expr(item.context_expr) as mgr,\ self.block.alloc_temp() as exit_func,\ self.block.alloc_temp() as value: # The code here has a subtle twist: It gets the exit function attribute @@ -595,19 +435,19 @@ def visit_With(self, node): # exit := type(mgr).__exit__ self.writer.write_checked_call2( exit_func, 'πg.GetAttr(πF, {}.Type().ToObject(), {}, nil)', - mgr.expr, self.block.intern('__exit__')) + mgr.expr, self.block.root.intern('__exit__')) # value := type(mgr).__enter__(mgr) self.writer.write_checked_call2( value, 'πg.GetAttr(πF, {}.Type().ToObject(), {}, nil)', - mgr.expr, self.block.intern('__enter__')) + mgr.expr, self.block.root.intern('__enter__')) self.writer.write_checked_call2( value, '{}.Call(πF, πg.Args{{{}}}, nil)', value.expr, mgr.expr) finally_label = self.block.genlabel(is_checkpoint=True) self.writer.write('πF.PushCheckpoint({})'.format(finally_label)) - if node.optional_vars: - self._tie_target(node.optional_vars, value.expr) + if item.optional_vars: + self._tie_target(item.optional_vars, value.expr) 
self._visit_each(node.body) self.writer.write('πF.PopCheckpoint()') self.writer.write_label(finally_label) @@ -619,7 +459,10 @@ def visit_With(self, node): self.block.alloc_temp('*πg.Type') as t: # temp := exit(mgr, *sys.exec_info()) tmpl = """\ - $exc, $tb = πF.ExcInfo() + $exc, $tb = nil, nil + if πE != nil { + \t$exc, $tb = πF.ExcInfo() + } if $exc != nil { \t$t = $exc.Type() \tif $swallow_exc, πE = $exit_func.Call(πF, πg.Args{$mgr, $t.ToObject(), $exc.ToObject(), $tb.ToObject()}, nil); πE != nil { @@ -645,21 +488,109 @@ def visit_With(self, node): if $exc != nil && $swallow_exc != true { \tπE = πF.Raise(nil, nil, nil) \tcontinue + } + if πR != nil { + \tcontinue }"""), exc=exc.expr, swallow_exc=swallow_exc_bool.expr) + def visit_function_inline(self, node): + """Returns an GeneratedExpr for a function with the given body.""" + # First pass collects the names of locals used in this function. Do this in + # a separate pass so that we know whether to resolve a name as a local or a + # global during the second pass. + func_visitor = block.FunctionBlockVisitor(node) + for child in node.body: + func_visitor.visit(child) + func_block = block.FunctionBlock(self.block, node.name, func_visitor.vars, + func_visitor.is_generator) + visitor = StatementVisitor(func_block, self.future_node) + # Indent so that the function body is aligned with the goto labels. + with visitor.writer.indent_block(): + visitor._visit_each(node.body) # pylint: disable=protected-access + + result = self.block.alloc_temp() + with self.block.alloc_temp('[]πg.Param') as func_args: + args = node.args + argc = len(args.args) + self.writer.write('{} = make([]πg.Param, {})'.format( + func_args.expr, argc)) + # The list of defaults only contains args for which a default value is + # specified so pad it with None to make it the same length as args. + defaults = [None] * (argc - len(args.defaults)) + args.defaults + for i, (a, d) in enumerate(zip(args.args, defaults)): + with self.visit_expr(d) if d else expr.nil_expr as default: + tmpl = '$args[$i] = πg.Param{Name: $name, Def: $default}' + self.writer.write_tmpl(tmpl, args=func_args.expr, i=i, + name=util.go_str(a.arg), default=default.expr) + flags = [] + if args.vararg: + flags.append('πg.CodeFlagVarArg') + if args.kwarg: + flags.append('πg.CodeFlagKWArg') + # The function object gets written to a temporary writer because we need + # it as an expression that we subsequently bind to some variable. 
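
A tiny standalone illustration of the default padding a few lines above (argument names and values are made up):

arg_names = ['a', 'b', 'c', 'd']
given_defaults = [1, 2]  # defaults exist for c and d only
defaults = [None] * (len(arg_names) - len(given_defaults)) + given_defaults
# zip(arg_names, defaults) == [('a', None), ('b', None), ('c', 1), ('d', 2)]
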
+ self.writer.write_tmpl( + '$result = πg.NewFunction(πg.NewCode($name, $filename, $args, ' + '$flags, func(πF *πg.Frame, πArgs []*πg.Object) ' + '(*πg.Object, *πg.BaseException) {', + result=result.name, name=util.go_str(node.name), + filename=util.go_str(self.block.root.filename), args=func_args.expr, + flags=' | '.join(flags) if flags else 0) + with self.writer.indent_block(): + for var in func_block.vars.values(): + if var.type != block.Var.TYPE_GLOBAL: + fmt = 'var {0} *πg.Object = {1}; _ = {0}' + self.writer.write(fmt.format( + util.adjust_local_name(var.name), var.init_expr)) + self.writer.write_temp_decls(func_block) + self.writer.write('var πR *πg.Object; _ = πR') + self.writer.write('var πE *πg.BaseException; _ = πE') + if func_block.is_generator: + self.writer.write( + 'return πg.NewGenerator(πF, func(πSent *πg.Object) ' + '(*πg.Object, *πg.BaseException) {') + with self.writer.indent_block(): + self.writer.write_block(func_block, visitor.writer.getvalue()) + self.writer.write('return nil, πE') + self.writer.write('}).ToObject(), nil') + else: + self.writer.write_block(func_block, visitor.writer.getvalue()) + self.writer.write(textwrap.dedent("""\ + if πE != nil { + \tπR = nil + } else if πR == nil { + \tπR = πg.None + } + return πR, πE""")) + self.writer.write('}), πF.Globals()).ToObject()') + return result + + _AUG_ASSIGN_TEMPLATES = { + ast.Add: 'πg.IAdd(πF, {lhs}, {rhs})', + ast.BitAnd: 'πg.IAnd(πF, {lhs}, {rhs})', + ast.Div: 'πg.IDiv(πF, {lhs}, {rhs})', + ast.FloorDiv: 'πg.IFloorDiv(πF, {lhs}, {rhs})', + ast.LShift: 'πg.ILShift(πF, {lhs}, {rhs})', + ast.Mod: 'πg.IMod(πF, {lhs}, {rhs})', + ast.Mult: 'πg.IMul(πF, {lhs}, {rhs})', + ast.BitOr: 'πg.IOr(πF, {lhs}, {rhs})', + ast.Pow: 'πg.IPow(πF, {lhs}, {rhs})', + ast.RShift: 'πg.IRShift(πF, {lhs}, {rhs})', + ast.Sub: 'πg.ISub(πF, {lhs}, {rhs})', + ast.BitXor: 'πg.IXor(πF, {lhs}, {rhs})', + } + def _assign_target(self, target, value): if isinstance(target, ast.Name): self.block.bind_var(self.writer, target.id, value) elif isinstance(target, ast.Attribute): - assert isinstance(target.ctx, ast.Store) - with self.expr_visitor.visit(target.value) as obj: + with self.visit_expr(target.value) as obj: self.writer.write_checked_call1( 'πg.SetAttr(πF, {}, {}, {})', obj.expr, - self.block.intern(target.attr), value) + self.block.root.intern(target.attr), value) elif isinstance(target, ast.Subscript): - assert isinstance(target.ctx, ast.Store) - with self.expr_visitor.visit(target.value) as mapping,\ - self.expr_visitor.visit(target.slice) as index: + with self.visit_expr(target.value) as mapping,\ + self.visit_expr(target.slice) as index: self.writer.write_checked_call1('πg.SetItem(πF, {}, {}, {})', mapping.expr, index.expr, value) else: @@ -678,74 +609,35 @@ def _build_assign_target(self, target, assigns): tmpl = 'πg.TieTarget{Target: &$temp}' return string.Template(tmpl).substitute(temp=temp.name) - def _import(self, name, index): - """Returns an expression for a Module object returned from ImportModule. + def _import_and_bind(self, imp): + """Generates code that imports a module and binds it to a variable. Args: - name: The fully qualified Python module name, e.g. foo.bar. - index: The element in the list of modules that this expression should - select. E.g. for 'foo.bar', 0 corresponds to the package foo and 1 - corresponds to the module bar. - Returns: - A Go expression evaluating to an *Object (upcast from a *Module.) + imp: Import object representing an import of the form "import x.y.z" or + "from x.y import z". 
Expects only a single binding. """ - parts = name.split('.') - code_objs = [] - for i in xrange(len(parts)): - package_name = '/'.join(parts[:i + 1]) - if package_name != self.block.full_package_name: - package = self.block.add_import(package_name) - code_objs.append('{}.Code'.format(package.alias)) - else: - code_objs.append('Code') - mod = self.block.alloc_temp() - with self.block.alloc_temp('[]*πg.Object') as mod_slice: - handles_expr = '[]*πg.Code{' + ', '.join(code_objs) + '}' + # Acquire handles to the Code objects in each Go package and call + # ImportModule to initialize all modules. + with self.block.alloc_temp() as mod, \ + self.block.alloc_temp('[]*πg.Object') as mod_slice: self.writer.write_checked_call2( - mod_slice, 'πg.ImportModule(πF, {}, {})', - util.go_str(name), handles_expr) - self.writer.write('{} = {}[{}]'.format(mod.name, mod_slice.expr, index)) - return mod - - def _import_native(self, name, values): - reflect_package = self.block.add_native_import('reflect') - import_name = name[len(_NATIVE_MODULE_PREFIX):] - # Work-around for importing go module from VCS - # TODO: support bzr|git|hg|svn from any server - package_name = None - for x in _KNOWN_VCS: - if import_name.startswith(x): - package_name = x + import_name[len(x):].replace('.', '/') - break - if not package_name: - package_name = import_name.replace('.', '/') - - package = self.block.add_native_import(package_name) - mod = self.block.alloc_temp() - with self.block.alloc_temp('map[string]*πg.Object') as members: - self.writer.write_tmpl('$members = map[string]*πg.Object{}', - members=members.name) - for v in values: - module_attr = v - with self.block.alloc_temp() as wrapped: - if v.startswith(_NATIVE_TYPE_PREFIX): - module_attr = v[len(_NATIVE_TYPE_PREFIX):] - with self.block.alloc_temp( - '{}.{}'.format(package.alias, module_attr)) as type_: - self.writer.write_checked_call2( - wrapped, 'πg.WrapNative(πF, {}.ValueOf({}))', - reflect_package.alias, type_.expr) - self.writer.write('{} = {}.Type().ToObject()'.format( - wrapped.name, wrapped.expr)) - else: + mod_slice, 'πg.ImportModule(πF, {})', util.go_str(imp.name)) + + # Bind the imported modules or members to variables in the current scope. + for binding in imp.bindings: + if binding.bind_type == imputil.Import.MODULE: + self.writer.write('{} = {}[{}]'.format( + mod.name, mod_slice.expr, binding.value)) + self.block.bind_var(self.writer, binding.alias, mod.expr) + else: + self.writer.write('{} = {}[{}]'.format( + mod.name, mod_slice.expr, imp.name.count('.'))) + # Binding a member of the imported module. 
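
The member branch below consumes Import objects like those the new tests construct for native imports; a hand-built sketch:

from grumpy.compiler import imputil

imp = imputil.Import('__go__/fmt', is_native=True)
# "from '__go__/fmt' import Printf as foo": the binding value is the member
# fetched with GetAttr and the alias is the local name it is bound to.
imp.add_binding(imputil.Import.MEMBER, 'foo', 'Printf')
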
+ with self.block.alloc_temp() as member: self.writer.write_checked_call2( - wrapped, 'πg.WrapNative(πF, {}.ValueOf({}.{}))', - reflect_package.alias, package.alias, v) - self.writer.write('{}[{}] = {}'.format( - members.name, util.go_str(module_attr), wrapped.expr)) - self.writer.write_checked_call2(mod, 'πg.ImportNativeModule(πF, {}, {})', - util.go_str(name), members.expr) - return mod + member, 'πg.GetAttr(πF, {}, {}, nil)', + mod.expr, self.block.root.intern(binding.value)) + self.block.bind_var(self.writer, binding.alias, member.expr) def _tie_target(self, target, value): if isinstance(target, ast.Name): @@ -764,6 +656,44 @@ def _visit_each(self, nodes): for node in nodes: self.visit(node) + def _visit_loop(self, testfunc, node): + start_label = self.block.genlabel(is_checkpoint=True) + else_label = self.block.genlabel(is_checkpoint=True) + end_label = self.block.genlabel() + with self.block.alloc_temp('bool') as breakvar: + self.block.push_loop(breakvar) + self.writer.write('πF.PushCheckpoint({})'.format(else_label)) + self.writer.write('{} = false'.format(breakvar.name)) + self.writer.write_label(start_label) + self.writer.write_tmpl(textwrap.dedent("""\ + if πE != nil || πR != nil { + \tcontinue + } + if $breakvar { + \tπF.PopCheckpoint() + \tgoto Label$end_label + }"""), breakvar=breakvar.expr, end_label=end_label) + with self.block.alloc_temp('bool') as testvar: + testfunc(testvar) + self.writer.write_tmpl(textwrap.dedent("""\ + if πE != nil || !$testvar { + \tcontinue + } + πF.PushCheckpoint($start_label)\ + """), testvar=testvar.name, start_label=start_label) + self._visit_each(node.body) + self.writer.write('continue') + # End the loop so that break applies to an outer loop if present. + self.block.pop_loop() + self.writer.write_label(else_label) + self.writer.write(textwrap.dedent("""\ + if πE != nil || πR != nil { + \tcontinue + }""")) + if node.orelse: + self._visit_each(node.orelse) + self.writer.write_label(end_label) + def _write_except_block(self, label, exc, except_node): self._write_py_context(except_node.lineno) self.writer.write_label(label) @@ -771,7 +701,6 @@ def _write_except_block(self, label, exc, except_node): self.block.bind_var(self.writer, except_node.name.id, '{}.ToObject()'.format(exc)) self._visit_each(except_node.body) - self.writer.write('πE = nil') self.writer.write('πF.RestoreExc(nil, nil)') def _write_except_dispatcher(self, exc, tb, handlers): @@ -792,7 +721,7 @@ def _write_except_dispatcher(self, exc, tb, handlers): for i, except_node in enumerate(handlers): handler_labels.append(self.block.genlabel()) if except_node.type: - with self.expr_visitor.visit(except_node.type) as type_,\ + with self.visit_expr(except_node.type) as type_,\ self.block.alloc_temp('bool') as is_inst: self.writer.write_checked_call2( is_inst, 'πg.IsInstance(πF, {}.ToObject(), {})', exc, type_.expr) @@ -815,6 +744,6 @@ def _write_except_dispatcher(self, exc, tb, handlers): def _write_py_context(self, lineno): if lineno: - line = self.block.lines[lineno - 1].strip() + line = self.block.root.buffer.source_line(lineno).strip() self.writer.write('// line {}: {}'.format(lineno, line)) self.writer.write('πF.SetLineno({})'.format(lineno)) diff --git a/compiler/stmt_test.py b/compiler/stmt_test.py index 99cd3c21..5f4cb0f5 100644 --- a/compiler/stmt_test.py +++ b/compiler/stmt_test.py @@ -16,16 +16,20 @@ """Tests for StatementVisitor.""" -import ast +from __future__ import unicode_literals + import re import subprocess import textwrap import unittest from grumpy.compiler import block 
+from grumpy.compiler import imputil from grumpy.compiler import shard_test from grumpy.compiler import stmt from grumpy.compiler import util +from grumpy import pythonparser +from grumpy.pythonparser import ast class StatementVisitorTest(unittest.TestCase): @@ -99,10 +103,11 @@ def testAugAssignBitAnd(self): foo &= 3 print foo"""))) - def testAugAssignUnsupportedOp(self): - expected = 'augmented assignment op not implemented' - self.assertRaisesRegexp(util.ParseError, expected, - _ParseAndVisit, 'foo **= bar') + def testAugAssignPow(self): + self.assertEqual((0, '64\n'), _GrumpRun(textwrap.dedent("""\ + foo = 8 + foo **= 2 + print foo"""))) def testClassDef(self): self.assertEqual((0, "\n"), _GrumpRun(textwrap.dedent("""\ @@ -292,129 +297,56 @@ def testImport(self): import sys print type(sys.modules)"""))) + def testImportFutureLateRaises(self): + regexp = 'from __future__ imports must occur at the beginning of the file' + self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit, + 'foo = bar\nfrom __future__ import print_function') + + def testFutureUnicodeLiterals(self): + want = "u'foo'\n" + self.assertEqual((0, want), _GrumpRun(textwrap.dedent("""\ + from __future__ import unicode_literals + print repr('foo')"""))) + + def testImportMember(self): + self.assertEqual((0, "\n"), _GrumpRun(textwrap.dedent("""\ + from sys import modules + print type(modules)"""))) + def testImportConflictingPackage(self): self.assertEqual((0, ''), _GrumpRun(textwrap.dedent("""\ import time - from __go__.time import Now"""))) + from "__go__/time" import Now"""))) def testImportNative(self): self.assertEqual((0, '1 1000000000\n'), _GrumpRun(textwrap.dedent("""\ - from __go__.time import Nanosecond, Second + from "__go__/time" import Nanosecond, Second print Nanosecond, Second"""))) - def testImportGrump(self): + def testImportGrumpy(self): self.assertEqual((0, ''), _GrumpRun(textwrap.dedent("""\ - from __go__.grumpy import Assert + from "__go__/grumpy" import Assert Assert(__frame__(), True, 'bad')"""))) - def testImportNativeModuleRaises(self): - regexp = r'for native imports use "from __go__\.xyz import \.\.\." 
syntax' - self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit, - 'import __go__.foo') - def testImportNativeType(self): self.assertEqual((0, "\n"), _GrumpRun(textwrap.dedent("""\ - from __go__.time import type_Duration as Duration + from "__go__/time" import Duration print Duration"""))) + def testImportWildcardMemberRaises(self): + regexp = 'wildcard member import is not implemented' + self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit, + 'from foo import *') + self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit, + 'from "__go__/foo" import *') + def testPrintStatement(self): self.assertEqual((0, 'abc 123\nfoo bar\n'), _GrumpRun(textwrap.dedent("""\ print 'abc', print '123' print 'foo', 'bar'"""))) - def testImportFromFuture(self): - testcases = [ - ('from __future__ import print_function', stmt.FUTURE_PRINT_FUNCTION), - ('from __future__ import generators', 0), - ('from __future__ import generators, print_function', - stmt.FUTURE_PRINT_FUNCTION), - ] - - for i, tc in enumerate(testcases): - source, want_flags = tc - mod = ast.parse(textwrap.dedent(source)) - node = mod.body[0] - got = stmt.import_from_future(node) - msg = '#{}: want {}, got {}'.format(i, want_flags, got) - self.assertEqual(want_flags, got, msg=msg) - - def testImportFromFutureParseError(self): - testcases = [ - # NOTE: move this group to testImportFromFuture as they are implemented - # by grumpy - ('from __future__ import absolute_import', - r'future feature \w+ not yet implemented'), - ('from __future__ import division', - r'future feature \w+ not yet implemented'), - ('from __future__ import unicode_literals', - r'future feature \w+ not yet implemented'), - - ('from __future__ import braces', 'not a chance'), - ('from __future__ import nonexistant_feature', - r'future feature \w+ is not defined'), - ] - - for tc in testcases: - source, want_regexp = tc - mod = ast.parse(source) - node = mod.body[0] - self.assertRaisesRegexp(util.ParseError, want_regexp, - stmt.import_from_future, node) - - def testImportWildcardMemberRaises(self): - regexp = r'wildcard member import is not implemented: from foo import *' - self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit, - 'from foo import *') - regexp = (r'wildcard member import is not ' - r'implemented: from __go__.foo import *') - self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit, - 'from __go__.foo import *') - - def testVisitFuture(self): - testcases = [ - ('from __future__ import print_function', - stmt.FUTURE_PRINT_FUNCTION, 1), - ("""\ - "module docstring" - - from __future__ import print_function - """, stmt.FUTURE_PRINT_FUNCTION, 3), - ("""\ - "module docstring" - - from __future__ import print_function, with_statement - from __future__ import nested_scopes - """, stmt.FUTURE_PRINT_FUNCTION, 4), - ] - - for tc in testcases: - source, flags, lineno = tc - mod = ast.parse(textwrap.dedent(source)) - future_features = stmt.visit_future(mod) - self.assertEqual(future_features.parser_flags, flags) - self.assertEqual(future_features.future_lineno, lineno) - - def testVisitFutureParseError(self): - testcases = [ - # future after normal imports - """\ - import os - from __future__ import print_function - """, - # future after non-docstring expression - """ - asd = 123 - from __future__ import print_function - """ - ] - - for source in testcases: - mod = ast.parse(textwrap.dedent(source)) - self.assertRaisesRegexp(util.ParseError, stmt.late_future, - stmt.visit_future, mod) - - def 
testFutureFeaturePrintFunction(self): + def testPrintFunction(self): want = "abc\n123\nabc 123\nabcx123\nabc 123 " self.assertEqual((0, want), _GrumpRun(textwrap.dedent("""\ "module docstring is ok to proceed __future__" @@ -513,8 +445,8 @@ def testTryFinally(self): finally: print 'bar'""")) self.assertEqual(1, result[0]) - # Some platforms show "exit status 1" message so don't test strict equality. - self.assertIn('foo bar\nfoo bar\nException\n', result[1]) + self.assertIn('foo bar\nfoo bar\n', result[1]) + self.assertIn('Exception\n', result[1]) def testWhile(self): self.assertEqual((0, '2\n1\n'), _GrumpRun(textwrap.dedent("""\ @@ -573,7 +505,7 @@ def testWriteExceptDispatcherBareExcept(self): 'exc', 'tb', handlers), [1, 2]) expected = re.compile(r'ResolveGlobal\(.*foo.*\bIsInstance\(.*' r'goto Label1.*goto Label2', re.DOTALL) - self.assertRegexpMatches(visitor.writer.out.getvalue(), expected) + self.assertRegexpMatches(visitor.writer.getvalue(), expected) def testWriteExceptDispatcherBareExceptionNotLast(self): visitor = stmt.StatementVisitor(_MakeModuleBlock()) @@ -593,19 +525,20 @@ def testWriteExceptDispatcherMultipleExcept(self): r'ResolveGlobal\(.*foo.*\bif .*\bIsInstance\(.*\{.*goto Label1.*' r'ResolveGlobal\(.*bar.*\bif .*\bIsInstance\(.*\{.*goto Label2.*' r'\bRaise\(exc\.ToObject\(\), nil, tb\.ToObject\(\)\)', re.DOTALL) - self.assertRegexpMatches(visitor.writer.out.getvalue(), expected) + self.assertRegexpMatches(visitor.writer.getvalue(), expected) def _MakeModuleBlock(): - return block.ModuleBlock('__main__', 'grumpy', 'grumpy/lib', '', [], - stmt.FutureFeatures()) + return block.ModuleBlock(None, '__main__', '', '', + imputil.FutureFeatures()) def _ParseAndVisit(source): - mod = ast.parse(source) - future_features = stmt.visit_future(mod) - b = block.ModuleBlock('__main__', 'grumpy', 'grumpy/lib', '', - source.split('\n'), future_features) + mod = pythonparser.parse(source) + _, future_features = imputil.parse_future_features(mod) + importer = imputil.Importer(None, 'foo', 'foo.py', False) + b = block.ModuleBlock(importer, '__main__', '', + source, future_features) visitor = stmt.StatementVisitor(b) visitor.visit(mod) return visitor diff --git a/compiler/util.py b/compiler/util.py index 97c36257..e70ce527 100644 --- a/compiler/util.py +++ b/compiler/util.py @@ -16,9 +16,13 @@ """Utilities for generating Go code.""" +from __future__ import unicode_literals + +import codecs import contextlib import cStringIO import string +import StringIO import textwrap @@ -26,21 +30,44 @@ _ESCAPES = {'\t': r'\t', '\r': r'\r', '\n': r'\n', '"': r'\"', '\\': r'\\'} -class ParseError(Exception): +# This is the max length of a direct allocation tuple supported by the runtime. +# This should match the number of specializations found in tuple.go. 
+MAX_DIRECT_TUPLE = 6 + + +class CompileError(Exception): def __init__(self, node, msg): if hasattr(node, 'lineno'): msg = 'line {}: {}'.format(node.lineno, msg) - super(ParseError, self).__init__(msg) + super(CompileError, self).__init__(msg) + + +class ParseError(CompileError): + pass + + +class ImportError(CompileError): # pylint: disable=redefined-builtin + pass + + +class LateFutureError(ImportError): + + def __init__(self, node): + msg = 'from __future__ imports must occur at the beginning of the file' + super(LateFutureError, self).__init__(node, msg) class Writer(object): """Utility class for writing blocks of Go code to a file-like object.""" def __init__(self, out=None): - self.out = out or cStringIO.StringIO() + self.out = codecs.getwriter('utf8')(out or cStringIO.StringIO()) self.indent_level = 0 + def getvalue(self): + return self.out.getvalue().decode('utf8') + @contextlib.contextmanager def indent_block(self, n=1): """A context manager that indents by n on entry and dedents on exit.""" @@ -60,7 +87,6 @@ def write_block(self, block_, body): block_: The Block object representing the code block. body: String containing Go code making up the body of the code block. """ - self.write('var πE *πg.BaseException; _ = πE') self.write('for ; πF.State() >= 0; πF.PopCheckpoint() {') with self.indent_block(): self.write('switch πF.State() {') @@ -72,18 +98,7 @@ def write_block(self, block_, body): # Assume that body is aligned with goto labels. with self.indent_block(-1): self.write(body) - self.write('return nil, nil') self.write('}') - self.write('return nil, πE') - - def write_import_block(self, imports): - if not imports: - return - self.write('import (') - with self.indent_block(): - for name in sorted(imports): - self.write('{} "{}"'.format(imports[name].alias, name)) - self.write(')') def write_label(self, label): with self.indent_block(-1): @@ -121,7 +136,7 @@ def dedent(self, n=1): def go_str(value): """Returns value as a valid Go string literal.""" - io = cStringIO.StringIO() + io = StringIO.StringIO() io.write('"') for c in value: if c in _ESCAPES: diff --git a/compiler/util_test.py b/compiler/util_test.py index 2f63adc6..d129ffa8 100644 --- a/compiler/util_test.py +++ b/compiler/util_test.py @@ -16,11 +16,13 @@ """Tests Writer and other utils.""" +from __future__ import unicode_literals + import unittest from grumpy.compiler import block +from grumpy.compiler import imputil from grumpy.compiler import util -from grumpy.compiler import stmt class WriterTest(unittest.TestCase): @@ -31,63 +33,50 @@ def testIndentBlock(self): with writer.indent_block(n=2): writer.write('bar') writer.write('baz') - self.assertEqual(writer.out.getvalue(), 'foo\n\t\tbar\nbaz\n') + self.assertEqual(writer.getvalue(), 'foo\n\t\tbar\nbaz\n') def testWriteBlock(self): writer = util.Writer() - mod_block = block.ModuleBlock('__main__', 'grumpy', 'grumpy/lib', '', - [], stmt.FutureFeatures()) + mod_block = block.ModuleBlock(None, '__main__', '', '', + imputil.FutureFeatures()) writer.write_block(mod_block, 'BODY') - output = writer.out.getvalue() + output = writer.getvalue() dispatch = 'switch πF.State() {\n\tcase 0:\n\tdefault: panic' self.assertIn(dispatch, output) - self.assertIn('return nil, nil\n}', output) - - def testWriteImportBlockEmptyImports(self): - writer = util.Writer() - writer.write_import_block({}) - self.assertEqual(writer.out.getvalue(), '') - - def testWriteImportBlockImportsSorted(self): - writer = util.Writer() - imports = {name: block.Package(name) for name in ('a', 'b', 'c')} - 
writer.write_import_block(imports) - self.assertEqual(writer.out.getvalue(), - 'import (\n\tπ_a "a"\n\tπ_b "b"\n\tπ_c "c"\n)\n') def testWriteMultiline(self): writer = util.Writer() writer.indent(2) writer.write('foo\nbar\nbaz\n') - self.assertEqual(writer.out.getvalue(), '\t\tfoo\n\t\tbar\n\t\tbaz\n') + self.assertEqual(writer.getvalue(), '\t\tfoo\n\t\tbar\n\t\tbaz\n') def testWritePyContext(self): writer = util.Writer() writer.write_py_context(12, 'print "foo"') - self.assertEqual(writer.out.getvalue(), '// line 12: print "foo"\n') + self.assertEqual(writer.getvalue(), '// line 12: print "foo"\n') def testWriteSkipBlankLine(self): writer = util.Writer() writer.write('foo\n\nbar') - self.assertEqual(writer.out.getvalue(), 'foo\nbar\n') + self.assertEqual(writer.getvalue(), 'foo\nbar\n') def testWriteTmpl(self): writer = util.Writer() writer.write_tmpl('$foo, $bar\n$baz', foo=1, bar=2, baz=3) - self.assertEqual(writer.out.getvalue(), '1, 2\n3\n') + self.assertEqual(writer.getvalue(), '1, 2\n3\n') def testIndent(self): writer = util.Writer() writer.indent(2) writer.write('foo') - self.assertEqual(writer.out.getvalue(), '\t\tfoo\n') + self.assertEqual(writer.getvalue(), '\t\tfoo\n') def testDedent(self): writer = util.Writer() writer.indent(4) writer.dedent(3) writer.write('foo') - self.assertEqual(writer.out.getvalue(), '\tfoo\n') + self.assertEqual(writer.getvalue(), '\tfoo\n') if __name__ == '__main__': diff --git a/lib/__builtin__.py b/lib/__builtin__.py index 4d176a81..a8ef6080 100644 --- a/lib/__builtin__.py +++ b/lib/__builtin__.py @@ -16,7 +16,7 @@ # pylint: disable=invalid-name -from __go__.grumpy import Builtins +from '__go__/grumpy' import Builtins for k, v in Builtins.iteritems(): diff --git a/lib/_random.py b/lib/_random.py index 9569bece..09688a19 100644 --- a/lib/_random.py +++ b/lib/_random.py @@ -14,9 +14,9 @@ """Generate pseudo random numbers. Should not be used for security purposes.""" -from __go__.math.rand import Uint32, Seed -from __go__.math import Pow -from __go__.time import Now +from '__go__/math/rand' import Uint32, Seed +from '__go__/math' import Pow +from '__go__/time' import Now BPF = 53 # Number of bits in a float diff --git a/lib/_syscall.py b/lib/_syscall.py new file mode 100644 index 00000000..db18dde5 --- /dev/null +++ b/lib/_syscall.py @@ -0,0 +1,31 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from '__go__/syscall' import EINTR + + +def invoke(func, *args): + while True: + result = func(*args) + if isinstance(result, tuple): + err = result[-1] + result = result[:-1] + else: + err = result + result = () + if err: + if err == EINTR: + continue + raise OSError(err.Error()) + return result diff --git a/lib/cStringIO.py b/lib/cStringIO.py new file mode 120000 index 00000000..7b1e326c --- /dev/null +++ b/lib/cStringIO.py @@ -0,0 +1 @@ +../third_party/stdlib/StringIO.py \ No newline at end of file diff --git a/lib/exceptions.py b/lib/exceptions.py index e888481f..d28fed6c 100644 --- a/lib/exceptions.py +++ b/lib/exceptions.py @@ -14,7 +14,7 @@ """Built-in exception classes.""" -from __go__.grumpy import ExceptionTypes +from '__go__/grumpy' import ExceptionTypes g = globals() diff --git a/lib/math.py b/lib/math.py index 3870d2e7..06c837e7 100644 --- a/lib/math.py +++ b/lib/math.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __go__.math import (Pi, E, Ceil, Copysign, Abs, Floor, Mod, Frexp, IsInf, +from '__go__/math' import (Pi, E, Ceil, Copysign, Abs, Floor, Mod, Frexp, IsInf, IsNaN, Exp2, Modf, Trunc, Exp, Expm1, Log, Log1p, Log10, Pow, Sqrt, Acos, Asin, Atan, Atan2, Hypot, Sin, Cos, Tan, Acosh, Asinh, Atanh, Sinh, Cosh, Tanh, Erf, Erfc, Gamma, Lgamma) # pylint: disable=g-multiple-import diff --git a/lib/operator.py b/lib/operator.py deleted file mode 100644 index 68473758..00000000 --- a/lib/operator.py +++ /dev/null @@ -1,33 +0,0 @@ -def eq(a, b): - return a == b - - -def le(a, b): - return a <= b - - -def lt(a, b): - return a < b - - -def ge(a, b): - return a >= b - - -def gt(a, b): - return a > b - - -def itemgetter(*items): - if len(items) == 1: - item = items[0] - def g(obj): - return obj[item] - else: - def g(obj): - return tuple(obj[item] for item in items) - return g - - -def ne(a, b): - return a != b diff --git a/lib/os/__init__.py b/lib/os/__init__.py index dd897ae6..0325cc5c 100644 --- a/lib/os/__init__.py +++ b/lib/os/__init__.py @@ -15,19 +15,28 @@ """Miscellaneous operating system interfaces.""" # pylint: disable=g-multiple-import +from '__go__/io/ioutil' import ReadDir +from '__go__/os' import (Chdir, Chmod, Environ, Getpid as getpid, Getwd, Pipe, + ProcAttr, Remove, StartProcess, Stat, Stdout, Stdin, + Stderr, Mkdir) +from '__go__/path/filepath' import Separator +from '__go__/grumpy' import (NewFileFromFD, StartThread, ToNative) +from '__go__/reflect' import MakeSlice +from '__go__/runtime' import GOOS +from '__go__/syscall' import (Close, SYS_FCNTL, Syscall, F_GETFD, Wait4, + WaitStatus, WNOHANG) +from '__go__/sync' import WaitGroup +from '__go__/time' import Second +import _syscall from os import path import stat as stat_module import sys -from __go__.io.ioutil import ReadDir -from __go__.os import Chdir, Chmod, Environ, Getwd, Remove, Stat -from __go__.path.filepath import Separator -from __go__.grumpy import NewFileFromFD -from __go__.syscall import Close, SYS_FCNTL, Syscall, F_GETFD -from __go__.time import Second + sep = chr(Separator) error = OSError # pylint: disable=invalid-name -curdir = "." +curdir = '.' 
+name = 'posix' environ = {} @@ -36,6 +45,12 @@ environ[k] = v +def mkdir(path, mode=0o777): + err = Mkdir(path, mode) + if err: + raise OSError(err.Error()) + + def chdir(path): err = Chdir(path) if err: @@ -60,7 +75,7 @@ def fdopen(fd, mode='r'): # pylint: disable=unused-argument _, _, err = Syscall(SYS_FCNTL, fd, F_GETFD, 0) if err: raise OSError(err.Error()) - return NewFileFromFD(fd) + return NewFileFromFD(fd, None) def listdir(p): @@ -77,6 +92,63 @@ def getcwd(): return dir +class _Popen(object): + + def __init__(self, command, mode): + self.mode = mode + self.result = None + self.r, self.w, err = Pipe() + if err: + raise OSError(err.Error()) + attr = ProcAttr.new() + # Create a slice using a reflect.Type returned by ToNative. + # TODO: There should be a cleaner way to create slices in Python. + files_type = ToNative(__frame__(), attr.Files).Type() + files = MakeSlice(files_type, 3, 3).Interface() + if self.mode == 'r': + fd = self.r.Fd() + files[0], files[1], files[2] = Stdin, self.w, Stderr + elif self.mode == 'w': + fd = self.w.Fd() + files[0], files[1], files[2] = self.r, Stdout, Stderr + else: + raise ValueError('invalid popen mode: %r', self.mode) + attr.Files = files + # TODO: There should be a cleaner way to create slices in Python. + args_type = ToNative(__frame__(), StartProcess).Type().In(1) + args = MakeSlice(args_type, 3, 3).Interface() + shell = environ['SHELL'] + args[0] = shell + args[1] = '-c' + args[2] = command + self.proc, err = StartProcess(shell, args, attr) + if err: + raise OSError(err.Error()) + self.wg = WaitGroup.new() + self.wg.Add(1) + StartThread(self._thread_func) + self.file = NewFileFromFD(fd, self.close) + + def _thread_func(self): + self.result = self.proc.Wait() + if self.mode == 'r': + self.w.Close() + self.wg.Done() + + def close(self, _): + if self.mode == 'w': + self.w.Close() + self.wg.Wait() + state, err = self.result + if err: + raise OSError(err.Error()) + return state.Sys() + + +def popen(command, mode='r'): + return _Popen(command, mode).file + + def remove(filepath): if stat_module.S_ISDIR(stat(filepath).st_mode): raise OSError('Operation not permitted: ' + filepath) @@ -120,3 +192,16 @@ def stat(filepath): if err: raise OSError(err.Error()) return StatResult(info) + + +unlink = remove + + +def waitpid(pid, options): + status = WaitStatus.new() + _syscall.invoke(Wait4, pid, status, options, None) + return pid, _encode_wait_result(status) + + +def _encode_wait_result(status): + return status.Signal() | (status.ExitStatus() << 8) diff --git a/lib/os/path.py b/lib/os/path.py index 09f30269..fd0b419d 100644 --- a/lib/os/path.py +++ b/lib/os/path.py @@ -14,8 +14,8 @@ """"Utilities for manipulating and inspecting OS paths.""" -from __go__.os import Stat -from __go__.path.filepath import Abs, Base, Clean, Dir as dirname, IsAbs as isabs, Join, Split # pylint: disable=g-multiple-import,unused-import +from '__go__/os' import Stat +from '__go__/path/filepath' import Abs, Base, Clean, Dir as dirname, IsAbs as isabs, Join, Split # pylint: disable=g-multiple-import,unused-import def abspath(path): diff --git a/lib/os_test.py b/lib/os_test.py index 63c15b3c..20864f1c 100644 --- a/lib/os_test.py +++ b/lib/os_test.py @@ -101,6 +101,41 @@ def TestFDOpenOSError(): raise AssertionError +def TestMkdir(): + path = 'foobarqux' + try: + os.stat(path) + except OSError: + pass + else: + raise AssertionError + try: + os.mkdir(path) + assert stat.S_ISDIR(os.stat(path).st_mode) + except OSError: + raise AssertionError + finally: + os.rmdir(path) + + +def 
TestPopenRead(): + f = os.popen('qux') + assert f.close() == 32512 + f = os.popen('echo hello') + try: + assert f.read() == 'hello\n' + finally: + assert f.close() == 0 + + +def TestPopenWrite(): + # TODO: We should verify the output but there's no good way to swap out stdout + # at the moment. + f = os.popen('cat', 'w') + f.write('popen write\n') + f.close() + + def TestRemove(): fd, path = tempfile.mkstemp() os.close(fd) @@ -208,5 +243,12 @@ def TestStatNoExist(): os.rmdir(path) +def TestWaitPid(): + try: + pid, status = os.waitpid(-1, os.WNOHANG) + except OSError as e: + assert 'no child processes' in str(e).lower() + + if __name__ == '__main__': weetest.RunTests() diff --git a/lib/select_.py b/lib/select_.py new file mode 100644 index 00000000..1aa7788f --- /dev/null +++ b/lib/select_.py @@ -0,0 +1,87 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from '__go__/syscall' import ( + FD_SETSIZE as _FD_SETSIZE, + Select as _Select, + FdSet as _FdSet, + Timeval as _Timeval +) +import _syscall +import math + + +class error(Exception): + pass + + +def select(rlist, wlist, xlist, timeout=None): + rlist_norm = _normalize_fd_list(rlist) + wlist_norm = _normalize_fd_list(wlist) + xlist_norm = _normalize_fd_list(xlist) + all_fds = rlist_norm + wlist_norm + xlist_norm + if not all_fds: + nfd = 0 + else: + nfd = max(all_fds) + 1 + + rfds = _make_fdset(rlist_norm) + wfds = _make_fdset(wlist_norm) + xfds = _make_fdset(xlist_norm) + + if timeout is None: + timeval = None + else: + timeval = _Timeval.new() + frac, integer = math.modf(timeout) + timeval.Sec = int(integer) + timeval.Usec = int(frac * 1000000.0) + _syscall.invoke(_Select, nfd, rfds, wfds, xfds, timeval) + return ([rlist[i] for i, fd in enumerate(rlist_norm) if _fdset_isset(fd, rfds)], + [wlist[i] for i, fd in enumerate(wlist_norm) if _fdset_isset(fd, wfds)], + [xlist[i] for i, fd in enumerate(xlist_norm) if _fdset_isset(fd, xfds)]) + + +def _fdset_set(fd, fds): + idx = fd / (_FD_SETSIZE / len(fds.Bits)) % len(fds.Bits) + pos = fd % (_FD_SETSIZE / len(fds.Bits)) + fds.Bits[idx] |= 1 << pos + + +def _fdset_isset(fd, fds): + idx = fd / (_FD_SETSIZE / len(fds.Bits)) % len(fds.Bits) + pos = fd % (_FD_SETSIZE / len(fds.Bits)) + return bool(fds.Bits[idx] & (1 << pos)) + + +def _make_fdset(fd_list): + fds = _FdSet.new() + for fd in fd_list: + _fdset_set(fd, fds) + return fds + + +def _normalize_fd_list(fds): + result = [] + # Python permits mutating the select fds list during fileno calls so we can't + # just use simple iteration over the list. 
See test_select_mutated in + # test_select.py + i = 0 + while i < len(fds): + fd = fds[i] + if hasattr(fd, 'fileno'): + fd = fd.fileno() + result.append(fd) + i += 1 + return result diff --git a/lib/stat.py b/lib/stat.py index dc61229a..15f4f9a4 100644 --- a/lib/stat.py +++ b/lib/stat.py @@ -15,11 +15,11 @@ """Interpreting stat() results.""" # pylint: disable=g-multiple-import -from __go__.os import ModeDir, ModePerm +from '__go__/os' import ModeDir, ModePerm def S_ISDIR(mode): # pylint: disable=invalid-name - return mode & ModeDir + return mode & ModeDir != 0 def S_IMODE(mode): # pylint: disable=invalid-name diff --git a/lib/sys.py b/lib/sys.py index cfbde9ef..9eac9a82 100644 --- a/lib/sys.py +++ b/lib/sys.py @@ -14,10 +14,10 @@ """System-specific parameters and functions.""" -from __go__.os import Args, Stdin, Stdout, Stderr -from __go__.grumpy import SysModules, MaxInt, NewFileFromFD # pylint: disable=g-multiple-import -from __go__.runtime import Version -from __go__.unicode import MaxRune +from '__go__/os' import Args +from '__go__/grumpy' import SysModules, MaxInt, Stdin as stdin, Stdout as stdout, Stderr as stderr # pylint: disable=g-multiple-import +from '__go__/runtime' import (GOOS as platform, Version) +from '__go__/unicode' import MaxRune argv = [] for arg in Args: @@ -32,11 +32,7 @@ warnoptions = [] # TODO: Support actual byteorder byteorder = 'little' - -stdin = NewFileFromFD(Stdin.Fd()) -stdout = NewFileFromFD(Stdout.Fd()) -stderr = NewFileFromFD(Stderr.Fd()) - +version = '2.7.13' class _Flags(object): """Container class for sys.flags.""" @@ -61,6 +57,10 @@ class _Flags(object): flags = _Flags() +def exc_clear(): + __frame__().__exc_clear__() + + def exc_info(): e, tb = __frame__().__exc_info__() # pylint: disable=undefined-variable t = None @@ -71,3 +71,13 @@ def exc_info(): def exit(code=None): # pylint: disable=redefined-builtin raise SystemExit(code) + + +def _getframe(depth=0): + f = __frame__() + while depth > 0 and f is not None: + f = f.f_back + depth -= 1 + if f is None: + raise ValueError('call stack is not deep enough') + return f diff --git a/lib/sys_test.py b/lib/sys_test.py index d185c403..20a8db34 100644 --- a/lib/sys_test.py +++ b/lib/sys_test.py @@ -32,6 +32,17 @@ def TestSysModules(): assert sys.modules['sys'] is not None +def TestExcClear(): + try: + raise RuntimeError + except: + assert all(sys.exc_info()), sys.exc_info() + sys.exc_clear() + assert not any(sys.exc_info()) + else: + assert False + + def TestExcInfoNoException(): assert sys.exc_info() == (None, None, None) @@ -75,6 +86,23 @@ def TestExitInvalidArgs(): assert False +def TestGetFrame(): + try: + sys._getframe(42, 42) + except TypeError: + pass + else: + assert False + try: + sys._getframe(2000000000) + except ValueError: + pass + else: + assert False + assert sys._getframe().f_code.co_name == '_getframe' + assert sys._getframe(1).f_code.co_name == 'TestGetFrame' + + if __name__ == '__main__': # This call will incidentally test sys.exit(). 
weetest.RunTests() diff --git a/lib/tempfile.py b/lib/tempfile.py index 03a02535..16e3956c 100644 --- a/lib/tempfile.py +++ b/lib/tempfile.py @@ -15,8 +15,8 @@ """Generate temporary files and directories.""" # pylint: disable=g-multiple-import -from __go__.io.ioutil import TempDir, TempFile -from __go__.syscall import Dup +from '__go__/io/ioutil' import TempDir, TempFile +from '__go__/syscall' import Dup # pylint: disable=redefined-builtin diff --git a/lib/thread.py b/lib/thread.py index 30095230..38917339 100644 --- a/lib/thread.py +++ b/lib/thread.py @@ -1,5 +1,62 @@ +from '__go__/grumpy' import NewTryableMutex, StartThread, ThreadCount + + +class error(Exception): + pass + + def get_ident(): f = __frame__() while f.f_back: f = f.f_back return id(f) + + +class LockType(object): + def __init__(self): + self._mutex = NewTryableMutex() + + def acquire(self, waitflag=1): + if waitflag: + self._mutex.Lock() + return True + return self._mutex.TryLock() + + def release(self): + self._mutex.Unlock() + + def __enter__(self): + self.acquire() + + def __exit__(self, *args): + self.release() + + +def allocate_lock(): + """Dummy implementation of thread.allocate_lock().""" + return LockType() + + +def start_new_thread(func, args, kwargs=None): + if kwargs is None: + kwargs = {} + l = allocate_lock() + ident = [] + def thread_func(): + ident.append(get_ident()) + l.release() + func(*args, **kwargs) + l.acquire() + StartThread(thread_func) + l.acquire() + return ident[0] + + +def stack_size(n=0): + if n: + raise error('grumpy does not support setting stack size') + return 0 + + +def _count(): + return ThreadCount diff --git a/lib/threading.py b/lib/threading.py deleted file mode 100644 index 2204cdff..00000000 --- a/lib/threading.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2016 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Concurrent programming functionality.""" - -from __go__.grumpy import StartThread -from __go__.sync import NewCond, type_Mutex as Mutex - - -class Event(object): - """Event is a way to signal conditions between threads.""" - - def __init__(self): - self._mutex = Mutex.new() - self._cond = NewCond(self._mutex) - self._is_set = False - - def set(self): - self._mutex.Lock() - try: - self._is_set = True - finally: - self._mutex.Unlock() - self._cond.Broadcast() - - # TODO: Support timeout param. - def wait(self): - self._mutex.Lock() - try: - while not self._is_set: - self._cond.Wait() - finally: - self._mutex.Unlock() - return True - - -class Thread(object): - """Thread is an activity to be executed concurrently.""" - - def __init__(self, target=None, args=()): - self._target = target - self._args = args - self._event = Event() - - def run(self): - self._target(*self._args) - - def start(self): - StartThread(self._run) - - # TODO: Support timeout param. 
- def join(self): - self._event.wait() - - def _run(self): - try: - self.run() - finally: - self._event.set() diff --git a/lib/threading_test.py b/lib/threading_test.py deleted file mode 100644 index eb8b4e61..00000000 --- a/lib/threading_test.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2016 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import threading -import time - -import weetest - - -def TestEvent(): - e = threading.Event() - target_result = [] - x = 'not ready' - def Target(): - e.wait() - target_result.append(x) - t = threading.Thread(target=Target) - t.start() - # Sleeping gives us some confidence that t had the opportunity to wait on e - # and that if e is broken (e.g. wait() returned immediately) then the test - # will fail below. - time.sleep(0.1) - x = 'ready' - e.set() - t.join() - assert target_result == ['ready'] - - -def TestThread(): - ran = [] - def Target(): - ran.append(True) - t = threading.Thread(target=Target) - t.start() - t.join() - assert ran - - -def TestThreadArgs(): - target_args = [] - def Target(*args): - target_args.append(args) - t = threading.Thread(target=Target, args=('foo', 42)) - t.start() - t.join() - assert target_args == [('foo', 42)] - - -if __name__ == '__main__': - weetest.RunTests() diff --git a/lib/time.py b/lib/time.py index 3bc575e5..0eef7e22 100644 --- a/lib/time.py +++ b/lib/time.py @@ -14,7 +14,74 @@ """Time access and conversions.""" -from __go__.time import Now, Second, Sleep # pylint: disable=g-multiple-import +from '__go__/time' import Local, Now, Second, Sleep, Unix, Date, UTC # pylint: disable=g-multiple-import + + +_strftime_directive_map = { + '%': '%', + 'a': 'Mon', + 'A': 'Monday', + 'b': 'Jan', + 'B': 'January', + 'c': NotImplemented, + 'd': '02', + 'H': '15', + 'I': '03', + 'j': NotImplemented, + 'L': '.000', + 'm': '01', + 'M': '04', + 'p': 'PM', + 'S': '05', + 'U': NotImplemented, + 'W': NotImplemented, + 'w': NotImplemented, + 'X': NotImplemented, + 'x': NotImplemented, + 'y': '06', + 'Y': '2006', + 'Z': 'MST', + 'z': '-0700', +} + + +class struct_time(tuple): #pylint: disable=invalid-name,missing-docstring + + def __init__(self, args): + super(struct_time, self).__init__(tuple, args) + self.tm_year = self[0] + self.tm_mon = self[1] + self.tm_mday = self[2] + self.tm_hour = self[3] + self.tm_min = self[4] + self.tm_sec = self[5] + self.tm_wday = self[6] + self.tm_yday = self[7] + self.tm_isdst = self[8] + + def __repr__(self): + return ("time.struct_time(tm_year=%s, tm_mon=%s, tm_mday=%s, " + "tm_hour=%s, tm_min=%s, tm_sec=%s, tm_wday=%s, " + "tm_yday=%s, tm_isdst=%s)") % self + + def __str__(self): + return repr(self) + + +def gmtime(seconds=None): + t = (Unix(seconds, 0) if seconds else Now()).UTC() + return struct_time((t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), + t.Second(), (t.Weekday() + 6) % 7, t.YearDay(), 0)) + + +def localtime(seconds=None): + t = (Unix(seconds, 0) if seconds else Now()).Local() + return struct_time((t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), + 
t.Second(), (t.Weekday() + 6) % 7, t.YearDay(), 0)) + + +def mktime(t): + return float(Date(t[0], t[1], t[2], t[3], t[4], t[5], 0, Local).Unix()) def sleep(secs): @@ -23,3 +90,31 @@ def sleep(secs): def time(): return float(Now().UnixNano()) / Second + + +def strftime(format, tt=None): # pylint: disable=missing-docstring,redefined-builtin + t = Unix(int(mktime(tt)), 0) if tt else Now() + ret = [] + prev, n = 0, format.find('%', 0, -1) + while n != -1: + ret.append(format[prev:n]) + next_ch = format[n + 1] + c = _strftime_directive_map.get(next_ch) + if c is NotImplemented: + raise NotImplementedError('Code: %' + next_ch + ' not yet supported') + if c: + ret.append(t.Format(c)) + else: + ret.append(format[n:n+2]) + n += 2 + prev, n = n, format.find('%', n, -1) + ret.append(format[prev:]) + return ''.join(ret) + + +# TODO: Calculate real value for daylight saving. +daylight = 0 + +# TODO: Use local DST instead of ''. +tzname = (Now().Zone()[0], '') + diff --git a/lib/time_test.py b/lib/time_test.py index bd3fdff1..4689c399 100644 --- a/lib/time_test.py +++ b/lib/time_test.py @@ -16,3 +16,7 @@ assert time.time() > 1000000000 assert time.time() < 3000000000 + +time_struct = (1999, 9, 19, 0, 0, 0, 6, 262, 0) +got = time.localtime(time.mktime(time_struct)) +assert got == time_struct, got diff --git a/lib/types_test.py b/lib/types_test.py index 64f95ce0..99b610fd 100644 --- a/lib/types_test.py +++ b/lib/types_test.py @@ -14,8 +14,8 @@ import types -from __go__.grumpy import (FunctionType, MethodType, ModuleType, StrType, # pylint: disable=g-multiple-import - TracebackType, TypeType) +from '__go__/grumpy' import (FunctionType, MethodType, ModuleType, StrType, # pylint: disable=g-multiple-import + TracebackType, TypeType) # Verify a sample of all types as a sanity check. assert types.FunctionType is FunctionType diff --git a/runtime/builtin_types.go b/runtime/builtin_types.go index 2c32eee7..888951e7 100644 --- a/runtime/builtin_types.go +++ b/runtime/builtin_types.go @@ -16,8 +16,9 @@ package grumpy import ( "fmt" + "math" "math/big" - "os" + "strings" "unicode" ) @@ -27,6 +28,11 @@ var ( builtinStr = NewStr("__builtin__") // ExceptionTypes contains all builtin exception types. ExceptionTypes []*Type + // EllipsisType is the object representing the Python 'ellipsis' type + EllipsisType = newSimpleType("ellipsis", ObjectType) + // Ellipsis is the singleton ellipsis object representing the Python + // 'Ellipsis' object. + Ellipsis = &Object{typ: EllipsisType} // NoneType is the object representing the Python 'NoneType' type. 
NoneType = newSimpleType("NoneType", ObjectType) // None is the singleton NoneType object representing the Python 'None' @@ -44,10 +50,23 @@ var ( UnboundLocal = newObject(unboundLocalType) ) +func ellipsisRepr(*Frame, *Object) (*Object, *BaseException) { + return NewStr("Ellipsis").ToObject(), nil +} + func noneRepr(*Frame, *Object) (*Object, *BaseException) { return NewStr("None").ToObject(), nil } +func notImplementedRepr(*Frame, *Object) (*Object, *BaseException) { + return NewStr("NotImplemented").ToObject(), nil +} + +func initEllipsisType(map[string]*Object) { + EllipsisType.flags &= ^(typeFlagInstantiable | typeFlagBasetype) + EllipsisType.slots.Repr = &unaryOpSlot{ellipsisRepr} +} + func initNoneType(map[string]*Object) { NoneType.flags &= ^(typeFlagInstantiable | typeFlagBasetype) NoneType.slots.Repr = &unaryOpSlot{noneRepr} @@ -55,6 +74,7 @@ func initNoneType(map[string]*Object) { func initNotImplementedType(map[string]*Object) { NotImplementedType.flags &= ^(typeFlagInstantiable | typeFlagBasetype) + NotImplementedType.slots.Repr = &unaryOpSlot{notImplementedRepr} } func initUnboundLocalType(map[string]*Object) { @@ -84,16 +104,20 @@ var builtinTypes = map[*Type]*builtinTypeInfo{ BaseExceptionType: {init: initBaseExceptionType, global: true}, BaseStringType: {init: initBaseStringType, global: true}, BoolType: {init: initBoolType, global: true}, + ByteArrayType: {init: initByteArrayType, global: true}, BytesWarningType: {global: true}, CodeType: {}, + ComplexType: {init: initComplexType, global: true}, ClassMethodType: {init: initClassMethodType, global: true}, DeprecationWarningType: {global: true}, dictItemIteratorType: {init: initDictItemIteratorType}, dictKeyIteratorType: {init: initDictKeyIteratorType}, dictValueIteratorType: {init: initDictValueIteratorType}, DictType: {init: initDictType, global: true}, + EllipsisType: {init: initEllipsisType, global: true}, enumerateType: {init: initEnumerateType, global: true}, EnvironmentErrorType: {global: true}, + EOFErrorType: {global: true}, ExceptionType: {global: true}, FileType: {init: initFileType, global: true}, FloatType: {init: initFloatType, global: true}, @@ -107,6 +131,7 @@ var builtinTypes = map[*Type]*builtinTypeInfo{ IndexErrorType: {global: true}, IntType: {init: initIntType, global: true}, IOErrorType: {global: true}, + KeyboardInterruptType: {global: true}, KeyErrorType: {global: true}, listIteratorType: {init: initListIteratorType}, ListType: {init: initListType, global: true}, @@ -317,20 +342,26 @@ func builtinDir(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { } d := NewDict() o := args[0] - if o.dict != nil { - raised := seqForEach(f, o.dict.ToObject(), func(k *Object) *BaseException { - return d.SetItem(f, k, None) - }) - if raised != nil { - return nil, raised + switch { + case o.isInstance(TypeType): + for _, t := range toTypeUnsafe(o).mro { + if raised := d.Update(f, t.Dict().ToObject()); raised != nil { + return nil, raised + } } - } - for _, t := range o.typ.mro { - raised := seqForEach(f, t.dict.ToObject(), func(k *Object) *BaseException { - return d.SetItem(f, k, None) - }) - if raised != nil { - return nil, raised + case o.isInstance(ModuleType): + d.Update(f, o.Dict().ToObject()) + default: + d = NewDict() + if dict := o.Dict(); dict != nil { + if raised := d.Update(f, dict.ToObject()); raised != nil { + return nil, raised + } + } + for _, t := range o.typ.mro { + if raised := d.Update(f, t.Dict().ToObject()); raised != nil { + return nil, raised + } } } l := d.Keys(f) @@ -340,6 +371,13 
@@ func builtinDir(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return l.ToObject(), nil } +func builtinDivMod(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkFunctionArgs(f, "divmod", args, ObjectType, ObjectType); raised != nil { + return nil, raised + } + return DivMod(f, args[0], args[1]) +} + func builtinFrame(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkFunctionArgs(f, "__frame__", args); raised != nil { return nil, raised @@ -401,19 +439,7 @@ func builtinHex(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkFunctionArgs(f, "hex", args, ObjectType); raised != nil { return nil, raised } - if method, raised := args[0].typ.mroLookup(f, NewStr("__hex__")); raised != nil { - return nil, raised - } else if method != nil { - return method.Call(f, args, nil) - } - if !args[0].isInstance(IntType) && !args[0].isInstance(LongType) { - return nil, f.RaiseType(TypeErrorType, "hex() argument can't be converted to hex") - } - s := numberToBase("0x", 16, args[0]) - if args[0].isInstance(LongType) { - s += "L" - } - return NewStr(s).ToObject(), nil + return Hex(f, args[0]) } func builtinID(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { @@ -490,23 +516,7 @@ func builtinOct(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkFunctionArgs(f, "oct", args, ObjectType); raised != nil { return nil, raised } - if method, raised := args[0].typ.mroLookup(f, NewStr("__oct__")); raised != nil { - return nil, raised - } else if method != nil { - return method.Call(f, args, nil) - } - if !args[0].isInstance(IntType) && !args[0].isInstance(LongType) { - return nil, f.RaiseType(TypeErrorType, "oct() argument can't be converted to oct") - } - s := numberToBase("0", 8, args[0]) - if args[0].isInstance(LongType) { - s += "L" - } - // For oct(0), return "0", not "00". 
- if s == "00" { - s = "0" - } - return NewStr(s).ToObject(), nil + return Oct(f, args[0]) } func builtinOpen(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { @@ -539,7 +549,7 @@ func builtinOrd(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { func builtinPrint(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { sep := " " end := "\n" - file := os.Stdout + file := Stdout for _, kwarg := range kwargs { switch kwarg.Name { case "sep": @@ -571,6 +581,37 @@ func builtinRange(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) return ListType.Call(f, []*Object{r}, nil) } +func builtinRawInput(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if len(args) > 1 { + msg := fmt.Sprintf("[raw_]input expcted at most 1 arguments, got %d", len(args)) + return nil, f.RaiseType(TypeErrorType, msg) + } + + if Stdin == nil { + msg := fmt.Sprintf("[raw_]input: lost sys.stdin") + return nil, f.RaiseType(RuntimeErrorType, msg) + } + + if Stdout == nil { + msg := fmt.Sprintf("[raw_]input: lost sys.stdout") + return nil, f.RaiseType(RuntimeErrorType, msg) + } + + if len(args) == 1 { + err := pyPrint(f, args, "", "", Stdout) + if err != nil { + return nil, err + } + } + + line, err := Stdin.reader.ReadString('\n') + if err != nil { + return nil, f.RaiseType(EOFErrorType, "EOF when reading a line") + } + line = strings.TrimRight(line, "\n") + return NewStr(line).ToObject(), nil +} + func builtinRepr(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { if raised := checkFunctionArgs(f, "repr", args, ObjectType); raised != nil { return nil, raised @@ -582,6 +623,45 @@ func builtinRepr(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return s.ToObject(), nil } +func builtinRound(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + argc := len(args) + expectedTypes := []*Type{ObjectType, ObjectType} + if argc == 1 { + expectedTypes = expectedTypes[:1] + } + if raised := checkFunctionArgs(f, "round", args, expectedTypes...); raised != nil { + return nil, raised + } + ndigits := 0 + if argc > 1 { + var raised *BaseException + if ndigits, raised = IndexInt(f, args[1]); raised != nil { + return nil, raised + } + } + number, isFloat := floatCoerce(args[0]) + + if !isFloat { + return nil, f.RaiseType(TypeErrorType, "a float is required") + } + + if math.IsNaN(number) || math.IsInf(number, 0) || number == 0.0 { + return NewFloat(number).ToObject(), nil + } + + neg := false + if number < 0 { + neg = true + number = -number + } + pow := math.Pow(10.0, float64(ndigits)) + result := math.Floor(number*pow+0.5) / pow + if neg { + result = -result + } + return NewFloat(result).ToObject(), nil +} + func builtinSetAttr(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkFunctionArgs(f, "setattr", args, ObjectType, StrType, ObjectType); raised != nil { return nil, raised @@ -602,6 +682,35 @@ func builtinSorted(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { return result, nil } +func builtinSum(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + argc := len(args) + expectedTypes := []*Type{ObjectType, ObjectType} + if argc == 1 { + expectedTypes = expectedTypes[:1] + } + if raised := checkFunctionArgs(f, "sum", args, expectedTypes...); raised != nil { + return nil, raised + } + var result *Object + if argc > 1 { + if args[1].typ == StrType { + return nil, f.RaiseType(TypeErrorType, "sum() can't sum strings [use ''.join(seq) instead]") + } + result = args[1] + } else { + result = 
NewInt(0).ToObject() + } + raised := seqForEach(f, args[0], func(o *Object) (raised *BaseException) { + result, raised = Add(f, result, o) + return raised + }) + + if raised != nil { + return nil, raised + } + return result, nil +} + func builtinUniChr(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkFunctionArgs(f, "unichr", args, IntType); raised != nil { return nil, raised @@ -631,9 +740,9 @@ Outer: elem, raised := Next(f, iter) if raised != nil { if raised.isInstance(StopIterationType) { + f.RestoreExc(nil, nil) break Outer } - f.RestoreExc(nil, nil) return nil, raised } elems[i] = elem @@ -645,6 +754,7 @@ Outer: func init() { builtinMap := map[string]*Object{ + "__debug__": False.ToObject(), "__frame__": newBuiltinFunction("__frame__", builtinFrame).ToObject(), "abs": newBuiltinFunction("abs", builtinAbs).ToObject(), "all": newBuiltinFunction("all", builtinAll).ToObject(), @@ -655,6 +765,8 @@ func init() { "cmp": newBuiltinFunction("cmp", builtinCmp).ToObject(), "delattr": newBuiltinFunction("delattr", builtinDelAttr).ToObject(), "dir": newBuiltinFunction("dir", builtinDir).ToObject(), + "divmod": newBuiltinFunction("divmod", builtinDivMod).ToObject(), + "Ellipsis": Ellipsis, "False": False.ToObject(), "getattr": newBuiltinFunction("getattr", builtinGetAttr).ToObject(), "globals": newBuiltinFunction("globals", builtinGlobals).ToObject(), @@ -677,9 +789,12 @@ func init() { "ord": newBuiltinFunction("ord", builtinOrd).ToObject(), "print": newBuiltinFunction("print", builtinPrint).ToObject(), "range": newBuiltinFunction("range", builtinRange).ToObject(), + "raw_input": newBuiltinFunction("raw_input", builtinRawInput).ToObject(), "repr": newBuiltinFunction("repr", builtinRepr).ToObject(), + "round": newBuiltinFunction("round", builtinRound).ToObject(), "setattr": newBuiltinFunction("setattr", builtinSetAttr).ToObject(), "sorted": newBuiltinFunction("sorted", builtinSorted).ToObject(), + "sum": newBuiltinFunction("sum", builtinSum).ToObject(), "True": True.ToObject(), "unichr": newBuiltinFunction("unichr", builtinUniChr).ToObject(), "zip": newBuiltinFunction("zip", builtinZip).ToObject(), @@ -823,9 +938,9 @@ func zipLongest(f *Frame, args Args) ([][]*Object, *BaseException) { if raised.isInstance(StopIterationType) { iters[i] = nil elems[i] = None + f.RestoreExc(nil, nil) continue } - f.RestoreExc(nil, nil) return nil, raised } noItems = false diff --git a/runtime/builtin_types_test.go b/runtime/builtin_types_test.go index d5120d71..732db19d 100644 --- a/runtime/builtin_types_test.go +++ b/runtime/builtin_types_test.go @@ -53,7 +53,7 @@ func TestBuiltinDelAttr(t *testing.T) { func TestBuiltinFuncs(t *testing.T) { f := NewRootFrame() - objectDir := ObjectType.dict.Keys(f) + objectDir := ObjectType.Dict().Keys(f) objectDir.Sort(f) fooType := newTestClass("Foo", []*Type{ObjectType}, newStringDict(map[string]*Object{"bar": None})) fooTypeDir := NewList(objectDir.elems...) @@ -64,6 +64,14 @@ func TestBuiltinFuncs(t *testing.T) { fooDir := NewList(fooTypeDir.elems...) 
fooDir.Append(NewStr("baz").ToObject()) fooDir.Sort(f) + dirModule := newTestModule("foo", "foo.py") + if raised := dirModule.Dict().SetItemString(NewRootFrame(), "bar", newObject(ObjectType)); raised != nil { + panic(raised) + } + dirModuleDir := dirModule.Dict().Keys(NewRootFrame()) + if raised := dirModuleDir.Sort(NewRootFrame()); raised != nil { + panic(raised) + } iter := mustNotRaise(Iter(f, mustNotRaise(xrangeType.Call(f, wrapArgs(5), nil)))) neg := wrapFuncForTest(func(f *Frame, i int) int { return -i }) raiseKey := wrapFuncForTest(func(f *Frame, o *Object) *BaseException { return f.RaiseType(RuntimeErrorType, "foo") }) @@ -75,11 +83,6 @@ func TestBuiltinFuncs(t *testing.T) { return NewStr("0octal").ToObject(), nil }).ToObject(), })) - indexType := newTestClass("Index", []*Type{ObjectType}, newStringDict(map[string]*Object{ - "__index__": newBuiltinFunction("__index__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { - return NewInt(123).ToObject(), nil - }).ToObject(), - })) badNonZeroType := newTestClass("BadNonZeroType", []*Type{ObjectType}, newStringDict(map[string]*Object{ "__nonzero__": newBuiltinFunction("__nonzero__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return nil, f.RaiseType(RuntimeErrorType, "foo") @@ -95,6 +98,11 @@ func TestBuiltinFuncs(t *testing.T) { return newObject(badNextType), nil }).ToObject(), })) + addType := newTestClass("Add", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__add__": newBuiltinFunction("__add__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return NewInt(1).ToObject(), nil + }).ToObject(), + })) fooBuiltinFunc := newBuiltinFunction("foo", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return newTestTuple(NewTuple(args.makeCopy()...), kwargs.makeDict()).ToObject(), nil }).ToObject() @@ -140,7 +148,7 @@ func TestBuiltinFuncs(t *testing.T) { {f: "bin", args: wrapArgs("foo"), wantExc: mustCreateException(TypeErrorType, "str object cannot be interpreted as an index")}, {f: "bin", args: wrapArgs(0.1), wantExc: mustCreateException(TypeErrorType, "float object cannot be interpreted as an index")}, {f: "bin", args: wrapArgs(1, 2, 3), wantExc: mustCreateException(TypeErrorType, "'bin' requires 1 arguments")}, - {f: "bin", args: wrapArgs(newObject(indexType)), want: NewStr("0b1111011").ToObject()}, + {f: "bin", args: wrapArgs(newTestIndexObject(123)), want: NewStr("0b1111011").ToObject()}, {f: "callable", args: wrapArgs(fooBuiltinFunc), want: True.ToObject()}, {f: "callable", args: wrapArgs(fooFunc), want: True.ToObject()}, {f: "callable", args: wrapArgs(0), want: False.ToObject()}, @@ -157,8 +165,27 @@ func TestBuiltinFuncs(t *testing.T) { {f: "chr", args: wrapArgs(), wantExc: mustCreateException(TypeErrorType, "'chr' requires 1 arguments")}, {f: "dir", args: wrapArgs(newObject(ObjectType)), want: objectDir.ToObject()}, {f: "dir", args: wrapArgs(newObject(fooType)), want: fooTypeDir.ToObject()}, + {f: "dir", args: wrapArgs(fooType), want: fooTypeDir.ToObject()}, {f: "dir", args: wrapArgs(foo), want: fooDir.ToObject()}, + {f: "dir", args: wrapArgs(dirModule), want: dirModuleDir.ToObject()}, {f: "dir", args: wrapArgs(), wantExc: mustCreateException(TypeErrorType, "'dir' requires 1 arguments")}, + {f: "divmod", args: wrapArgs(12, 7), want: NewTuple2(NewInt(1).ToObject(), NewInt(5).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(-12, 7), want: NewTuple2(NewInt(-2).ToObject(), NewInt(2).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(12, -7), 
want: NewTuple2(NewInt(-2).ToObject(), NewInt(-2).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(-12, -7), want: NewTuple2(NewInt(1).ToObject(), NewInt(-5).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(MaxInt, MinInt), want: NewTuple2(NewInt(-1).ToObject(), NewInt(-1).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(MinInt, MaxInt), want: NewTuple2(NewInt(-2).ToObject(), NewInt(MaxInt-1).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(MinInt, -1), want: NewTuple2(NewLong(new(big.Int).Neg(minIntBig)).ToObject(), NewLong(big.NewInt(0)).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(big.NewInt(12), big.NewInt(7)), want: NewTuple2(NewLong(big.NewInt(1)).ToObject(), NewLong(big.NewInt(5)).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(big.NewInt(-12), big.NewInt(7)), want: NewTuple2(NewLong(big.NewInt(-2)).ToObject(), NewLong(big.NewInt(2)).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(big.NewInt(12), big.NewInt(-7)), want: NewTuple2(NewLong(big.NewInt(-2)).ToObject(), NewLong(big.NewInt(-2)).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(big.NewInt(-12), big.NewInt(-7)), want: NewTuple2(NewLong(big.NewInt(1)).ToObject(), NewLong(big.NewInt(-5)).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(3.25, 1.0), want: NewTuple2(NewFloat(3.0).ToObject(), NewFloat(0.25).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(-3.25, 1.0), want: NewTuple2(NewFloat(-4.0).ToObject(), NewFloat(0.75).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(3.25, -1.0), want: NewTuple2(NewFloat(-4.0).ToObject(), NewFloat(-0.75).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(-3.25, -1.0), want: NewTuple2(NewFloat(3.0).ToObject(), NewFloat(-0.25).ToObject()).ToObject()}, + {f: "divmod", args: wrapArgs(NewStr("a"), NewStr("b")), wantExc: mustCreateException(TypeErrorType, "unsupported operand type(s) for divmod(): 'str' and 'str'")}, + {f: "divmod", args: wrapArgs(), wantExc: mustCreateException(TypeErrorType, "'divmod' requires 2 arguments")}, {f: "getattr", args: wrapArgs(None, NewStr("foo").ToObject(), NewStr("bar").ToObject()), want: NewStr("bar").ToObject()}, {f: "getattr", args: wrapArgs(None, NewStr("foo").ToObject()), wantExc: mustCreateException(AttributeErrorType, "'NoneType' object has no attribute 'foo'")}, {f: "hasattr", args: wrapArgs(newObject(ObjectType), NewStr("foo").ToObject()), want: False.ToObject()}, @@ -275,6 +302,23 @@ func TestBuiltinFuncs(t *testing.T) { {f: "repr", args: wrapArgs(NewUnicode("abc")), want: NewStr("u'abc'").ToObject()}, {f: "repr", args: wrapArgs(newTestTuple("foo", "bar")), want: NewStr("('foo', 'bar')").ToObject()}, {f: "repr", args: wrapArgs("a", "b", "c"), wantExc: mustCreateException(TypeErrorType, "'repr' requires 1 arguments")}, + {f: "round", args: wrapArgs(1234.567), want: NewFloat(1235).ToObject()}, + {f: "round", args: wrapArgs(1234.111), want: NewFloat(1234).ToObject()}, + {f: "round", args: wrapArgs(-1234.567), want: NewFloat(-1235).ToObject()}, + {f: "round", args: wrapArgs(-1234.111), want: NewFloat(-1234).ToObject()}, + {f: "round", args: wrapArgs(1234.567, newTestIndexObject(0)), want: NewFloat(1235).ToObject()}, + {f: "round", args: wrapArgs("foo"), wantExc: mustCreateException(TypeErrorType, "a float is required")}, + {f: "round", args: wrapArgs(12.5, 0), want: NewFloat(13.0).ToObject()}, + {f: "round", args: wrapArgs(-12.5, 0), want: NewFloat(-13.0).ToObject()}, + {f: "round", args: wrapArgs(12.5, 3), want: NewFloat(12.5).ToObject()}, + {f: "round", args: 
wrapArgs(1234.5, 1), want: NewFloat(1234.5).ToObject()}, + {f: "round", args: wrapArgs(1234.5, 1), want: NewFloat(1234.5).ToObject()}, + {f: "round", args: wrapArgs(1234.56, 1), want: NewFloat(1234.6).ToObject()}, + {f: "round", args: wrapArgs(-1234.56, 1), want: NewFloat(-1234.6).ToObject()}, + {f: "round", args: wrapArgs(-1234.56, -2), want: NewFloat(-1200.0).ToObject()}, + {f: "round", args: wrapArgs(-1234.56, -8), want: NewFloat(0.0).ToObject()}, + {f: "round", args: wrapArgs(63.4, -3), want: NewFloat(0.0).ToObject()}, + {f: "round", args: wrapArgs(63.4, -2), want: NewFloat(100.0).ToObject()}, {f: "sorted", args: wrapArgs(NewList()), want: NewList().ToObject()}, {f: "sorted", args: wrapArgs(newTestList("foo", "bar")), want: newTestList("bar", "foo").ToObject()}, {f: "sorted", args: wrapArgs(newTestList(true, false)), want: newTestList(false, true).ToObject()}, @@ -284,6 +328,13 @@ func TestBuiltinFuncs(t *testing.T) { {f: "sorted", args: wrapArgs(newTestDict("foo", 1, "bar", 2)), want: newTestList("bar", "foo").ToObject()}, {f: "sorted", args: wrapArgs(1), wantExc: mustCreateException(TypeErrorType, "'int' object is not iterable")}, {f: "sorted", args: wrapArgs(newTestList("foo", "bar"), 2), wantExc: mustCreateException(TypeErrorType, "'sorted' requires 1 arguments")}, + {f: "sum", args: wrapArgs(newTestList(1, 2, 3, 4)), want: NewInt(10).ToObject()}, + {f: "sum", args: wrapArgs(newTestList(1, 2), 3), want: NewFloat(6).ToObject()}, + {f: "sum", args: wrapArgs(newTestList(2, 1.1)), want: NewFloat(3.1).ToObject()}, + {f: "sum", args: wrapArgs(newTestList(2, 1.1, 2)), want: NewFloat(5.1).ToObject()}, + {f: "sum", args: wrapArgs(newTestList(2, 1.1, 2.0)), want: NewFloat(5.1).ToObject()}, + {f: "sum", args: wrapArgs(newTestList(1), newObject(addType)), want: NewInt(1).ToObject()}, + {f: "sum", args: wrapArgs(newTestList(newObject(addType)), newObject(addType)), want: NewInt(1).ToObject()}, {f: "unichr", args: wrapArgs(0), want: NewUnicode("\x00").ToObject()}, {f: "unichr", args: wrapArgs(65), want: NewStr("A").ToObject()}, {f: "unichr", args: wrapArgs(0x120000), wantExc: mustCreateException(ValueErrorType, "unichr() arg not in range(0x10ffff)")}, @@ -327,6 +378,13 @@ func TestBuiltinGlobals(t *testing.T) { } } +func TestEllipsisRepr(t *testing.T) { + cas := invokeTestCase{args: wrapArgs(Ellipsis), want: NewStr("Ellipsis").ToObject()} + if err := runInvokeMethodTestCase(EllipsisType, "__repr__", &cas); err != "" { + t.Error(err) + } +} + func TestNoneRepr(t *testing.T) { cas := invokeTestCase{args: wrapArgs(None), want: NewStr("None").ToObject()} if err := runInvokeMethodTestCase(NoneType, "__repr__", &cas); err != "" { @@ -334,6 +392,13 @@ func TestNoneRepr(t *testing.T) { } } +func TestNotImplementedRepr(t *testing.T) { + cas := invokeTestCase{args: wrapArgs(NotImplemented), want: NewStr("NotImplemented").ToObject()} + if err := runInvokeMethodTestCase(NotImplementedType, "__repr__", &cas); err != "" { + t.Error(err) + } +} + // captureStdout invokes a function closure which writes to stdout and captures // its output as string. 
func captureStdout(f *Frame, fn func() *BaseException) (string, *BaseException) { @@ -341,10 +406,10 @@ func captureStdout(f *Frame, fn func() *BaseException) (string, *BaseException) if err != nil { return "", f.RaiseType(RuntimeErrorType, fmt.Sprintf("failed to open pipe: %v", err)) } - oldStdout := os.Stdout - os.Stdout = w + oldStdout := Stdout + Stdout = NewFileFromFD(w.Fd(), nil) defer func() { - os.Stdout = oldStdout + Stdout = oldStdout }() done := make(chan struct{}) var raised *BaseException @@ -364,7 +429,8 @@ func captureStdout(f *Frame, fn func() *BaseException) (string, *BaseException) return buf.String(), nil } -func TestBuiltinPrint(t *testing.T) { +// TODO(corona10): Re-enable once #282 is addressed. +/*func TestBuiltinPrint(t *testing.T) { fun := wrapFuncForTest(func(f *Frame, args *Tuple, kwargs KWArgs) (string, *BaseException) { return captureStdout(f, func() *BaseException { _, raised := builtinPrint(NewRootFrame(), args.elems, kwargs) @@ -384,7 +450,7 @@ func TestBuiltinPrint(t *testing.T) { t.Error(err) } } -} +}*/ func TestBuiltinSetAttr(t *testing.T) { setattr := mustNotRaise(Builtins.GetItemString(NewRootFrame(), "setattr")) @@ -415,3 +481,67 @@ func TestBuiltinSetAttr(t *testing.T) { } } } + +// TODO(corona10): Re-enable once #282 is addressed. +/*func TestRawInput(t *testing.T) { + fun := wrapFuncForTest(func(f *Frame, s string, args ...*Object) (*Object, *BaseException) { + // Create a fake Stdin for input test. + stdinFile, w, err := os.Pipe() + if err != nil { + return nil, f.RaiseType(RuntimeErrorType, fmt.Sprintf("failed to open pipe: %v", err)) + } + + go func() { + w.Write([]byte(s)) + w.Close() + }() + + oldStdin := Stdin + Stdin = NewFileFromFD(stdinFile.Fd(), nil) + defer func() { + Stdin = oldStdin + stdinFile.Close() + }() + + var input *Object + output, raised := captureStdout(f, func() *BaseException { + in, raised := builtinRawInput(f, args, nil) + input = in + return raised + }) + + if raised != nil { + return nil, raised + } + + return newTestTuple(input, output).ToObject(), nil + }) + + cases := []invokeTestCase{ + {args: wrapArgs("HelloGrumpy\n", ""), want: newTestTuple("HelloGrumpy", "").ToObject()}, + {args: wrapArgs("HelloGrumpy\n", "ShouldBeShown\nShouldBeShown\t"), want: newTestTuple("HelloGrumpy", "ShouldBeShown\nShouldBeShown\t").ToObject()}, + {args: wrapArgs("HelloGrumpy\n", 5, 4), wantExc: mustCreateException(TypeErrorType, "[raw_]input expcted at most 1 arguments, got 2")}, + {args: wrapArgs("HelloGrumpy\nHelloGrumpy\n", ""), want: newTestTuple("HelloGrumpy", "").ToObject()}, + {args: wrapArgs("HelloGrumpy\nHelloGrumpy\n", "ShouldBeShown\nShouldBeShown\t"), want: newTestTuple("HelloGrumpy", "ShouldBeShown\nShouldBeShown\t").ToObject()}, + {args: wrapArgs("HelloGrumpy\nHelloGrumpy\n", 5, 4), wantExc: mustCreateException(TypeErrorType, "[raw_]input expcted at most 1 arguments, got 2")}, + {args: wrapArgs("", ""), wantExc: mustCreateException(EOFErrorType, "EOF when reading a line")}, + {args: wrapArgs("", "ShouldBeShown\nShouldBeShown\t"), wantExc: mustCreateException(EOFErrorType, "EOF when reading a line")}, + {args: wrapArgs("", 5, 4), wantExc: mustCreateException(TypeErrorType, "[raw_]input expcted at most 1 arguments, got 2")}, + } + + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } + +}*/ + +func newTestIndexObject(index int) *Object { + indexType := newTestClass("Index", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__index__": newBuiltinFunction("__index__", 
func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return NewInt(index).ToObject(), nil + }).ToObject(), + })) + return newObject(indexType) +} diff --git a/runtime/bytearray.go b/runtime/bytearray.go new file mode 100644 index 00000000..c7b149cb --- /dev/null +++ b/runtime/bytearray.go @@ -0,0 +1,195 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package grumpy + +import ( + "bytes" + "fmt" + "reflect" + "sync" +) + +var ( + // ByteArrayType is the object representing the Python 'bytearray' type. + ByteArrayType = newBasisType("bytearray", reflect.TypeOf(ByteArray{}), toByteArrayUnsafe, ObjectType) +) + +// ByteArray represents Python 'bytearray' objects. +type ByteArray struct { + Object + mutex sync.RWMutex + value []byte +} + +func toByteArrayUnsafe(o *Object) *ByteArray { + return (*ByteArray)(o.toPointer()) +} + +// ToObject upcasts a to an Object. +func (a *ByteArray) ToObject() *Object { + return &a.Object +} + +// Value returns the underlying bytes held by a. +func (a *ByteArray) Value() []byte { + return a.value +} + +func byteArrayEq(f *Frame, v, w *Object) (*Object, *BaseException) { + return byteArrayCompare(v, w, False, True, False), nil +} + +func byteArrayGE(f *Frame, v, w *Object) (*Object, *BaseException) { + return byteArrayCompare(v, w, False, True, True), nil +} + +func byteArrayGetItem(f *Frame, o, key *Object) (result *Object, raised *BaseException) { + a := toByteArrayUnsafe(o) + if key.typ.slots.Index != nil { + index, raised := IndexInt(f, key) + if raised != nil { + return nil, raised + } + a.mutex.RLock() + elems := a.Value() + if index, raised = seqCheckedIndex(f, len(elems), index); raised == nil { + result = NewInt(int(elems[index])).ToObject() + } + a.mutex.RUnlock() + return result, raised + } + if key.isInstance(SliceType) { + a.mutex.RLock() + elems := a.Value() + start, stop, step, sliceLen, raised := toSliceUnsafe(key).calcSlice(f, len(elems)) + if raised == nil { + value := make([]byte, sliceLen) + if step == 1 { + copy(value, elems[start:stop]) + } else { + i := 0 + for j := start; j != stop; j += step { + value[i] = elems[j] + i++ + } + } + result = (&ByteArray{Object: Object{typ: ByteArrayType}, value: value}).ToObject() + } + a.mutex.RUnlock() + return result, raised + } + return nil, f.RaiseType(TypeErrorType, fmt.Sprintf("bytearray indices must be integers or slice, not %s", key.typ.Name())) +} + +func byteArrayGT(f *Frame, v, w *Object) (*Object, *BaseException) { + return byteArrayCompare(v, w, False, False, True), nil +} + +func byteArrayInit(f *Frame, o *Object, args Args, _ KWArgs) (*Object, *BaseException) { + if raised := checkFunctionArgs(f, "__init__", args, IntType); raised != nil { + return nil, raised + } + a := toByteArrayUnsafe(o) + a.mutex.Lock() + a.value = make([]byte, toIntUnsafe(args[0]).Value()) + a.mutex.Unlock() + return None, nil +} + +func byteArrayLE(f *Frame, v, w *Object) (*Object, *BaseException) { + return byteArrayCompare(v, w, True, True, 
False), nil +} + +func byteArrayLT(f *Frame, v, w *Object) (*Object, *BaseException) { + return byteArrayCompare(v, w, True, False, False), nil +} + +func byteArrayNative(f *Frame, o *Object) (reflect.Value, *BaseException) { + a := toByteArrayUnsafe(o) + a.mutex.RLock() + result := reflect.ValueOf(a.Value()) + a.mutex.RUnlock() + return result, nil +} + +func byteArrayNE(f *Frame, v, w *Object) (*Object, *BaseException) { + return byteArrayCompare(v, w, True, False, True), nil +} + +func byteArrayRepr(f *Frame, o *Object) (*Object, *BaseException) { + a := toByteArrayUnsafe(o) + a.mutex.RLock() + s, raised := Repr(f, NewStr(string(a.Value())).ToObject()) + a.mutex.RUnlock() + if raised != nil { + return nil, raised + } + return NewStr(fmt.Sprintf("bytearray(b%s)", s.Value())).ToObject(), nil +} + +func byteArrayStr(f *Frame, o *Object) (*Object, *BaseException) { + a := toByteArrayUnsafe(o) + a.mutex.RLock() + s := string(a.Value()) + a.mutex.RUnlock() + return NewStr(s).ToObject(), nil +} + +func initByteArrayType(dict map[string]*Object) { + ByteArrayType.slots.Eq = &binaryOpSlot{byteArrayEq} + ByteArrayType.slots.GE = &binaryOpSlot{byteArrayGE} + ByteArrayType.slots.GetItem = &binaryOpSlot{byteArrayGetItem} + ByteArrayType.slots.GT = &binaryOpSlot{byteArrayGT} + ByteArrayType.slots.Init = &initSlot{byteArrayInit} + ByteArrayType.slots.LE = &binaryOpSlot{byteArrayLE} + ByteArrayType.slots.LT = &binaryOpSlot{byteArrayLT} + ByteArrayType.slots.Native = &nativeSlot{byteArrayNative} + ByteArrayType.slots.NE = &binaryOpSlot{byteArrayNE} + ByteArrayType.slots.Repr = &unaryOpSlot{byteArrayRepr} + ByteArrayType.slots.Str = &unaryOpSlot{byteArrayStr} +} + +func byteArrayCompare(v, w *Object, ltResult, eqResult, gtResult *Int) *Object { + if v == w { + return eqResult.ToObject() + } + // For simplicity we make a copy of w if it's a str or bytearray. This + // is inefficient and it may be useful to optimize. + var data []byte + switch { + case w.isInstance(StrType): + data = []byte(toStrUnsafe(w).Value()) + case w.isInstance(ByteArrayType): + a := toByteArrayUnsafe(w) + a.mutex.RLock() + data = make([]byte, len(a.value)) + copy(data, a.value) + a.mutex.RUnlock() + default: + return NotImplemented + } + a := toByteArrayUnsafe(v) + a.mutex.RLock() + cmp := bytes.Compare(a.value, data) + a.mutex.RUnlock() + switch cmp { + case -1: + return ltResult.ToObject() + case 0: + return eqResult.ToObject() + default: + return gtResult.ToObject() + } +} diff --git a/runtime/bytearray_test.go b/runtime/bytearray_test.go new file mode 100644 index 00000000..8945f331 --- /dev/null +++ b/runtime/bytearray_test.go @@ -0,0 +1,118 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package grumpy + +import ( + "testing" +) + +func TestByteArrayCompare(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(newTestByteArray(""), newTestByteArray("")), want: compareAllResultEq}, + {args: wrapArgs(newTestByteArray("foo"), newTestByteArray("foo")), want: compareAllResultEq}, + {args: wrapArgs(newTestByteArray(""), newTestByteArray("foo")), want: compareAllResultLT}, + {args: wrapArgs(newTestByteArray("foo"), newTestByteArray("")), want: compareAllResultGT}, + {args: wrapArgs(newTestByteArray("bar"), newTestByteArray("baz")), want: compareAllResultLT}, + {args: wrapArgs(newTestByteArray(""), ""), want: compareAllResultEq}, + {args: wrapArgs(newTestByteArray("foo"), "foo"), want: compareAllResultEq}, + {args: wrapArgs(newTestByteArray(""), "foo"), want: compareAllResultLT}, + {args: wrapArgs(newTestByteArray("foo"), ""), want: compareAllResultGT}, + {args: wrapArgs(newTestByteArray("bar"), "baz"), want: compareAllResultLT}, + } + for _, cas := range cases { + if err := runInvokeTestCase(compareAll, &cas); err != "" { + t.Error(err) + } + } +} + +func TestByteArrayGetItem(t *testing.T) { + badIndexType := newTestClass("badIndex", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__index__": newBuiltinFunction("__index__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return nil, f.RaiseType(ValueErrorType, "wut") + }).ToObject(), + })) + cases := []invokeTestCase{ + {args: wrapArgs(newTestByteArray("bar"), 1), want: NewInt(97).ToObject()}, + {args: wrapArgs(newTestByteArray("foo"), 3.14), wantExc: mustCreateException(TypeErrorType, "bytearray indices must be integers or slice, not float")}, + {args: wrapArgs(newTestByteArray("baz"), -1), want: NewInt(122).ToObject()}, + {args: wrapArgs(newTestByteArray("baz"), -4), wantExc: mustCreateException(IndexErrorType, "index out of range")}, + {args: wrapArgs(newTestByteArray(""), 0), wantExc: mustCreateException(IndexErrorType, "index out of range")}, + {args: wrapArgs(newTestByteArray("foo"), 3), wantExc: mustCreateException(IndexErrorType, "index out of range")}, + {args: wrapArgs(newTestByteArray("bar"), newTestSlice(None, 2)), want: newTestByteArray("ba").ToObject()}, + {args: wrapArgs(newTestByteArray("bar"), newTestSlice(1, 3)), want: newTestByteArray("ar").ToObject()}, + {args: wrapArgs(newTestByteArray("bar"), newTestSlice(1, None)), want: newTestByteArray("ar").ToObject()}, + {args: wrapArgs(newTestByteArray("foobarbaz"), newTestSlice(1, 8, 2)), want: newTestByteArray("obra").ToObject()}, + {args: wrapArgs(newTestByteArray("abc"), newTestSlice(None, None, -1)), want: newTestByteArray("cba").ToObject()}, + {args: wrapArgs(newTestByteArray("bar"), newTestSlice(1, 2, 0)), wantExc: mustCreateException(ValueErrorType, "slice step cannot be zero")}, + {args: wrapArgs(newTestByteArray("123"), newObject(badIndexType)), wantExc: mustCreateException(ValueErrorType, "wut")}, + } + for _, cas := range cases { + if err := runInvokeMethodTestCase(ByteArrayType, "__getitem__", &cas); err != "" { + t.Error(err) + } + } +} + +func TestByteArrayInit(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(3), want: newTestByteArray("\x00\x00\x00").ToObject()}, + {args: wrapArgs(newObject(ObjectType)), wantExc: mustCreateException(TypeErrorType, `'__init__' requires a 'int' object but received a "object"`)}, + } + for _, cas := range cases { + if err := runInvokeTestCase(ByteArrayType.ToObject(), &cas); err != "" { + t.Error(err) + } + } +} + +func TestByteArrayNative(t *testing.T) { + val, raised := 
ToNative(NewRootFrame(), newTestByteArray("foo").ToObject()) + if raised != nil { + t.Fatalf("bytearray.__native__ raised %v", raised) + } + data, ok := val.Interface().([]byte) + if !ok || string(data) != "foo" { + t.Fatalf(`bytearray.__native__() = %v, want []byte("123")`, val.Interface()) + } +} + +func TestByteArrayRepr(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(newTestByteArray("")), want: NewStr("bytearray(b'')").ToObject()}, + {args: wrapArgs(newTestByteArray("foo")), want: NewStr("bytearray(b'foo')").ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(Repr), &cas); err != "" { + t.Error(err) + } + } +} + +func TestByteArrayStr(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(newTestByteArray("")), want: NewStr("").ToObject()}, + {args: wrapArgs(newTestByteArray("foo")), want: NewStr("foo").ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(ToStr), &cas); err != "" { + t.Error(err) + } + } +} + +func newTestByteArray(s string) *ByteArray { + return &ByteArray{Object: Object{typ: ByteArrayType}, value: []byte(s)} +} diff --git a/runtime/code.go b/runtime/code.go index 5a523793..6ff7793d 100644 --- a/runtime/code.go +++ b/runtime/code.go @@ -75,7 +75,11 @@ func (c *Code) Eval(f *Frame, globals *Dict, args Args, kwargs KWArgs) (*Object, } } else { _, tb := f.ExcInfo() - tb = newTraceback(f, tb) + if f.code != nil { + // The root frame has no code object so don't include it + // in the traceback. + tb = newTraceback(f, tb) + } f.RestoreExc(raised, tb) } return ret, raised diff --git a/runtime/complex.go b/runtime/complex.go new file mode 100644 index 00000000..752ca4c5 --- /dev/null +++ b/runtime/complex.go @@ -0,0 +1,536 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package grumpy + +import ( + "errors" + "fmt" + "math" + "math/cmplx" + "reflect" + "regexp" + "strconv" + "strings" +) + +// ComplexType is the object representing the Python 'complex' type. +var ComplexType = newBasisType("complex", reflect.TypeOf(Complex{}), toComplexUnsafe, ObjectType) + +// Complex represents Python 'complex' objects. +type Complex struct { + Object + value complex128 +} + +// NewComplex returns a new Complex holding the given complex value. +func NewComplex(value complex128) *Complex { + return &Complex{Object{typ: ComplexType}, value} +} + +func toComplexUnsafe(o *Object) *Complex { + return (*Complex)(o.toPointer()) +} + +// ToObject upcasts c to an Object. +func (c *Complex) ToObject() *Object { + return &c.Object +} + +// Value returns the underlying complex value held by c. 
+func (c *Complex) Value() complex128 { + return c.value +} + +func complexAbs(f *Frame, o *Object) (*Object, *BaseException) { + c := toComplexUnsafe(o).Value() + return NewFloat(cmplx.Abs(c)).ToObject(), nil +} + +func complexAdd(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexArithmeticOp(f, "__add__", v, w, func(lhs, rhs complex128) complex128 { + return lhs + rhs + }) +} + +func complexCompareNotSupported(f *Frame, v, w *Object) (*Object, *BaseException) { + if w.isInstance(IntType) || w.isInstance(LongType) || w.isInstance(FloatType) || w.isInstance(ComplexType) { + return nil, f.RaiseType(TypeErrorType, "no ordering relation is defined for complex numbers") + } + return NotImplemented, nil +} + +func complexComplex(f *Frame, o *Object) (*Object, *BaseException) { + return o, nil +} + +func complexDiv(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexDivModOp(f, "__div__", v, w, func(v, w complex128) (complex128, bool) { + if w == 0 { + return 0, false + } + return v / w, true + }) +} + +func complexDivMod(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexDivAndModOp(f, "__divmod__", v, w, func(v, w complex128) (complex128, complex128, bool) { + if w == 0 { + return 0, 0, false + } + return complexFloorDivOp(v, w), complexModOp(v, w), true + }) +} + +func complexEq(f *Frame, v, w *Object) (*Object, *BaseException) { + e, ok := complexCompare(toComplexUnsafe(v), w) + if !ok { + return NotImplemented, nil + } + return GetBool(e).ToObject(), nil +} + +func complexFloorDiv(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexDivModOp(f, "__floordiv__", v, w, func(v, w complex128) (complex128, bool) { + if w == 0 { + return 0, false + } + return complexFloorDivOp(v, w), true + }) +} + +func complexHash(f *Frame, o *Object) (*Object, *BaseException) { + v := toComplexUnsafe(o).Value() + hashCombined := hashFloat(real(v)) + 1000003*hashFloat(imag(v)) + if hashCombined == -1 { + hashCombined = -2 + } + return NewInt(hashCombined).ToObject(), nil +} + +func complexMod(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexDivModOp(f, "__mod__", v, w, func(v, w complex128) (complex128, bool) { + if w == 0 { + return 0, false + } + return complexModOp(v, w), true + }) +} + +func complexMul(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexArithmeticOp(f, "__mul__", v, w, func(lhs, rhs complex128) complex128 { + return lhs * rhs + }) +} + +func complexNE(f *Frame, v, w *Object) (*Object, *BaseException) { + e, ok := complexCompare(toComplexUnsafe(v), w) + if !ok { + return NotImplemented, nil + } + return GetBool(!e).ToObject(), nil +} + +func complexNeg(f *Frame, o *Object) (*Object, *BaseException) { + c := toComplexUnsafe(o).Value() + return NewComplex(-c).ToObject(), nil +} + +func complexNew(f *Frame, t *Type, args Args, _ KWArgs) (*Object, *BaseException) { + argc := len(args) + if argc == 0 { + return newObject(t), nil + } + if argc > 2 { + return nil, f.RaiseType(TypeErrorType, "'__new__' of 'complex' requires at most 2 arguments") + } + if t != ComplexType { + // Allocate a plain complex then copy it's value into an object + // of the complex subtype. 
+ x, raised := complexNew(f, ComplexType, args, nil) + if raised != nil { + return nil, raised + } + result := toComplexUnsafe(newObject(t)) + result.value = toComplexUnsafe(x).Value() + return result.ToObject(), nil + } + if complexSlot := args[0].typ.slots.Complex; complexSlot != nil && argc == 1 { + c, raised := complexConvert(complexSlot, f, args[0]) + if raised != nil { + return nil, raised + } + return c.ToObject(), nil + } + if args[0].isInstance(StrType) { + if argc > 1 { + return nil, f.RaiseType(TypeErrorType, "complex() can't take second arg if first is a string") + } + s := toStrUnsafe(args[0]).Value() + result, err := parseComplex(s) + if err != nil { + return nil, f.RaiseType(ValueErrorType, "complex() arg is a malformed string") + } + return NewComplex(result).ToObject(), nil + } + if argc > 1 && args[1].isInstance(StrType) { + return nil, f.RaiseType(TypeErrorType, "complex() second arg can't be a string") + } + cr, raised := complex128Convert(f, args[0]) + if raised != nil { + return nil, raised + } + var ci complex128 + if argc > 1 { + ci, raised = complex128Convert(f, args[1]) + if raised != nil { + return nil, raised + } + } + + // Logically it should be enough to return this: + // NewComplex(cr + ci*1i).ToObject() + // But Go complex arithmetic does not satisfy all conditions; for instance: + // cr := complex(math.Inf(1), 0) + // ci := complex(math.Inf(-1), 0) + // fmt.Println(cr + ci*1i) + // Output is (NaN-Infi), instead of (+Inf-Infi). + return NewComplex(complex(real(cr)-imag(ci), imag(cr)+real(ci))).ToObject(), nil +} + +func complexNonZero(f *Frame, o *Object) (*Object, *BaseException) { + return GetBool(toComplexUnsafe(o).Value() != 0).ToObject(), nil +} + +func complexPos(f *Frame, o *Object) (*Object, *BaseException) { + return o, nil +} + +func complexPow(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexArithmeticOp(f, "__pow__", v, w, func(lhs, rhs complex128) complex128 { + return cmplx.Pow(lhs, rhs) + }) +} + +func complexRAdd(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexArithmeticOp(f, "__radd__", v, w, func(lhs, rhs complex128) complex128 { + return lhs + rhs + }) +} + +func complexRDiv(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexDivModOp(f, "__rdiv__", v, w, func(v, w complex128) (complex128, bool) { + if v == 0 { + return 0, false + } + return w / v, true + }) +} + +func complexRDivMod(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexDivAndModOp(f, "__rdivmod__", v, w, func(v, w complex128) (complex128, complex128, bool) { + if v == 0 { + return 0, 0, false + } + return complexFloorDivOp(w, v), complexModOp(w, v), true + }) +} + +func complexRepr(f *Frame, o *Object) (*Object, *BaseException) { + c := toComplexUnsafe(o).Value() + rs, is := "", "" + pre, post := "", "" + sign := "" + if real(c) == 0.0 { + is = strconv.FormatFloat(imag(c), 'g', -1, 64) + } else { + pre = "(" + rs = strconv.FormatFloat(real(c), 'g', -1, 64) + is = strconv.FormatFloat(imag(c), 'g', -1, 64) + if imag(c) >= 0.0 || math.IsNaN(imag(c)) { + sign = "+" + } + post = ")" + } + rs = unsignPositiveInf(strings.ToLower(rs)) + is = unsignPositiveInf(strings.ToLower(is)) + return NewStr(fmt.Sprintf("%s%s%s%sj%s", pre, rs, sign, is, post)).ToObject(), nil +} + +func complexRFloorDiv(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexDivModOp(f, "__rfloordiv__", v, w, func(v, w complex128) (complex128, bool) { + if v == 0 { + return 0, false + } + return complexFloorDivOp(w, v), true +
}) +} + +func complexRMod(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexDivModOp(f, "__rmod__", v, w, func(v, w complex128) (complex128, bool) { + if v == 0 { + return 0, false + } + return complexModOp(w, v), true + }) +} + +func complexRMul(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexArithmeticOp(f, "__rmul__", v, w, func(lhs, rhs complex128) complex128 { + return rhs * lhs + }) +} + +func complexRPow(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexArithmeticOp(f, "__rpow__", v, w, func(lhs, rhs complex128) complex128 { + return cmplx.Pow(rhs, lhs) + }) +} + +func complexRSub(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexArithmeticOp(f, "__rsub__", v, w, func(lhs, rhs complex128) complex128 { + return rhs - lhs + }) +} + +func complexSub(f *Frame, v, w *Object) (*Object, *BaseException) { + return complexArithmeticOp(f, "__sub__", v, w, func(lhs, rhs complex128) complex128 { + return lhs - rhs + }) +} + +func initComplexType(dict map[string]*Object) { + ComplexType.slots.Abs = &unaryOpSlot{complexAbs} + ComplexType.slots.Add = &binaryOpSlot{complexAdd} + ComplexType.slots.Complex = &unaryOpSlot{complexComplex} + ComplexType.slots.Div = &binaryOpSlot{complexDiv} + ComplexType.slots.DivMod = &binaryOpSlot{complexDivMod} + ComplexType.slots.Eq = &binaryOpSlot{complexEq} + ComplexType.slots.FloorDiv = &binaryOpSlot{complexFloorDiv} + ComplexType.slots.GE = &binaryOpSlot{complexCompareNotSupported} + ComplexType.slots.GT = &binaryOpSlot{complexCompareNotSupported} + ComplexType.slots.Hash = &unaryOpSlot{complexHash} + ComplexType.slots.LE = &binaryOpSlot{complexCompareNotSupported} + ComplexType.slots.LT = &binaryOpSlot{complexCompareNotSupported} + ComplexType.slots.Mod = &binaryOpSlot{complexMod} + ComplexType.slots.Mul = &binaryOpSlot{complexMul} + ComplexType.slots.NE = &binaryOpSlot{complexNE} + ComplexType.slots.Neg = &unaryOpSlot{complexNeg} + ComplexType.slots.New = &newSlot{complexNew} + ComplexType.slots.NonZero = &unaryOpSlot{complexNonZero} + ComplexType.slots.Pos = &unaryOpSlot{complexPos} + ComplexType.slots.Pow = &binaryOpSlot{complexPow} + ComplexType.slots.RAdd = &binaryOpSlot{complexRAdd} + ComplexType.slots.RDiv = &binaryOpSlot{complexRDiv} + ComplexType.slots.RDivMod = &binaryOpSlot{complexRDivMod} + ComplexType.slots.RFloorDiv = &binaryOpSlot{complexRFloorDiv} + ComplexType.slots.Repr = &unaryOpSlot{complexRepr} + ComplexType.slots.RMod = &binaryOpSlot{complexRMod} + ComplexType.slots.RMul = &binaryOpSlot{complexRMul} + ComplexType.slots.RPow = &binaryOpSlot{complexRPow} + ComplexType.slots.RSub = &binaryOpSlot{complexRSub} + ComplexType.slots.Sub = &binaryOpSlot{complexSub} +} + +func complex128Convert(f *Frame, o *Object) (complex128, *BaseException) { + if complexSlot := o.typ.slots.Complex; complexSlot != nil { + c, raised := complexConvert(complexSlot, f, o) + if raised != nil { + return complex(0, 0), raised + } + return c.Value(), nil + } else if floatSlot := o.typ.slots.Float; floatSlot != nil { + result, raised := floatConvert(floatSlot, f, o) + if raised != nil { + return complex(0, 0), raised + } + return complex(result.Value(), 0), nil + } else { + return complex(0, 0), f.RaiseType(TypeErrorType, "complex() argument must be a string or a number") + } +} + +func complexArithmeticOp(f *Frame, method string, v, w *Object, fun func(v, w complex128) complex128) (*Object, *BaseException) { + if w.isInstance(ComplexType) { + return NewComplex(fun(toComplexUnsafe(v).Value(), 
toComplexUnsafe(w).Value())).ToObject(), nil + } + + floatW, ok := floatCoerce(w) + if !ok { + if math.IsInf(floatW, 0) { + return nil, f.RaiseType(OverflowErrorType, "long int too large to convert to float") + } + return NotImplemented, nil + } + return NewComplex(fun(toComplexUnsafe(v).Value(), complex(floatW, 0))).ToObject(), nil +} + +// complexCoerce will coerce any numeric type to a complex. If all is +// well, it will return the complex128 value, and true (OK). If an overflow +// occurs, it will return either (+Inf, false) or (-Inf, false) depending +// on whether the source value was too large or too small. Note that if the +// source number is an infinite float, the result will be infinite without +// overflow, (+-Inf, true). +// If the input is not a number, it will return (0, false). +func complexCoerce(o *Object) (complex128, bool) { + if o.isInstance(ComplexType) { + return toComplexUnsafe(o).Value(), true + } + floatO, ok := floatCoerce(o) + if !ok { + if math.IsInf(floatO, 0) { + return complex(floatO, 0.0), false + } + return 0, false + } + return complex(floatO, 0.0), true +} + +func complexCompare(v *Complex, w *Object) (bool, bool) { + lhsr := real(v.Value()) + rhs, ok := complexCoerce(w) + if !ok { + return false, false + } + return lhsr == real(rhs) && imag(v.Value()) == imag(rhs), true +} + +func complexConvert(complexSlot *unaryOpSlot, f *Frame, o *Object) (*Complex, *BaseException) { + result, raised := complexSlot.Fn(f, o) + if raised != nil { + return nil, raised + } + if !result.isInstance(ComplexType) { + exc := fmt.Sprintf("__complex__ returned non-complex (type %s)", result.typ.Name()) + return nil, f.RaiseType(TypeErrorType, exc) + } + return toComplexUnsafe(result), nil +} + +func complexDivModOp(f *Frame, method string, v, w *Object, fun func(v, w complex128) (complex128, bool)) (*Object, *BaseException) { + complexW, ok := complexCoerce(w) + if !ok { + if cmplx.IsInf(complexW) { + return nil, f.RaiseType(OverflowErrorType, "long int too large to convert to complex") + } + return NotImplemented, nil + } + x, ok := fun(toComplexUnsafe(v).Value(), complexW) + if !ok { + return nil, f.RaiseType(ZeroDivisionErrorType, "complex division or modulo by zero") + } + return NewComplex(x).ToObject(), nil +} + +func complexDivAndModOp(f *Frame, method string, v, w *Object, fun func(v, w complex128) (complex128, complex128, bool)) (*Object, *BaseException) { + complexW, ok := complexCoerce(w) + if !ok { + if cmplx.IsInf(complexW) { + return nil, f.RaiseType(OverflowErrorType, "long int too large to convert to complex") + } + return NotImplemented, nil + } + q, m, ok := fun(toComplexUnsafe(v).Value(), complexW) + if !ok { + return nil, f.RaiseType(ZeroDivisionErrorType, "complex division or modulo by zero") + } + return NewTuple2(NewComplex(q).ToObject(), NewComplex(m).ToObject()).ToObject(), nil +} + +func complexFloorDivOp(v, w complex128) complex128 { + return complex(math.Floor(real(v/w)), 0) +} + +func complexModOp(v, w complex128) complex128 { + return v - complexFloorDivOp(v, w)*w +} + +const ( + blank = iota + real1 + imag1 + real2 + sign2 + imag3 + real4 + sign5 + onlyJ +) + +// parseComplex converts the string s to a complex number. +// If the string is well-formed (one of these forms: <float>, <float>j, +// <float>±<float>j, <float>±j, ±j or j, where <float> is +// any numeric string that's acceptable by strconv.ParseFloat(s, 64)), +// parseComplex returns the respective complex128 number.
+func parseComplex(s string) (complex128, error) { + c := strings.Count(s, "(") + if (c > 1) || (c == 1 && strings.Count(s, ")") != 1) { + return complex(0, 0), errors.New("Malformed complex string, more than one matching parentheses") + } + ts := strings.TrimSpace(s) + ts = strings.Trim(ts, "()") + ts = strings.TrimSpace(ts) + re := `(?i)(?:(?:(?:(?:\d*\.\d+)|(?:\d+\.?))(?:[Ee][+-]?\d+)?)|(?:infinity)|(?:nan)|(?:inf))` + fre := `[-+]?` + re + sre := `[-+]` + re + fsfj := `(?:(?P<real1>` + fre + `)(?P<imag1>` + sre + `)j)` + fsj := `(?:(?P<real2>` + fre + `)(?P<sign2>[-+])j)` + fj := `(?P<imag3>` + fre + `)j` + f := `(?P<real4>` + fre + `)` + sj := `(?P<sign5>[-+])j` + j := `(?P<onlyJ>j)` + r := regexp.MustCompile(`^(?:` + fsfj + `|` + fsj + `|` + fj + `|` + f + `|` + sj + `|` + j + `)$`) + subs := r.FindStringSubmatch(ts) + if subs == nil { + return complex(0, 0), errors.New("Malformed complex string, no matching pattern found") + } + if subs[real1] != "" && subs[imag1] != "" { + r, _ := strconv.ParseFloat(unsignNaN(subs[real1]), 64) + i, err := strconv.ParseFloat(unsignNaN(subs[imag1]), 64) + return complex(r, i), err + } + if subs[real2] != "" && subs[sign2] != "" { + r, err := strconv.ParseFloat(unsignNaN(subs[real2]), 64) + if subs[sign2] == "-" { + return complex(r, -1), err + } + return complex(r, 1), err + } + if subs[imag3] != "" { + i, err := strconv.ParseFloat(unsignNaN(subs[imag3]), 64) + return complex(0, i), err + } + if subs[real4] != "" { + r, err := strconv.ParseFloat(unsignNaN(subs[real4]), 64) + return complex(r, 0), err + } + if subs[sign5] != "" { + if subs[sign5] == "-" { + return complex(0, -1), nil + } + return complex(0, 1), nil + } + if subs[onlyJ] != "" { + return complex(0, 1), nil + } + return complex(0, 0), errors.New("Malformed complex string") +} + +func unsignNaN(s string) string { + ls := strings.ToLower(s) + if ls == "-nan" || ls == "+nan" { + return "nan" + } + return s +} diff --git a/runtime/complex_test.go b/runtime/complex_test.go new file mode 100644 index 00000000..61b68722 --- /dev/null +++ b/runtime/complex_test.go @@ -0,0 +1,532 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
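As an aside on the component-wise construction used in complexNew above, the following standalone sketch (illustration only, not part of this change) shows why the naive cr + ci*1i expression is avoided when infinities are involved:

package main

import (
	"fmt"
	"math"
)

func main() {
	cr := complex(math.Inf(1), 0)  // real part of the result comes from here
	ci := complex(math.Inf(-1), 0) // imaginary part of the result comes from here
	// Multiplying by 1i turns -Inf*0 into NaN in the real component.
	fmt.Println(cr + ci*1i) // prints (NaN-Infi)
	// Building the result component-wise preserves both infinities.
	fmt.Println(complex(real(cr)-imag(ci), imag(cr)+real(ci))) // prints (+Inf-Infi)
}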
+ +package grumpy + +import ( + "errors" + "math" + "math/big" + "math/cmplx" + "testing" +) + +func TestComplexAbs(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(complex(0, 0)), want: NewFloat(0).ToObject()}, + {args: wrapArgs(complex(1, 1)), want: NewFloat(1.4142135623730951).ToObject()}, + {args: wrapArgs(complex(1, 2)), want: NewFloat(2.23606797749979).ToObject()}, + {args: wrapArgs(complex(3, 4)), want: NewFloat(5).ToObject()}, + {args: wrapArgs(complex(-3, 4)), want: NewFloat(5).ToObject()}, + {args: wrapArgs(complex(3, -4)), want: NewFloat(5).ToObject()}, + {args: wrapArgs(-complex(3, 4)), want: NewFloat(5).ToObject()}, + {args: wrapArgs(complex(0.123456e-3, 0)), want: NewFloat(0.000123456).ToObject()}, + {args: wrapArgs(complex(0.123456e-3, 3.14151692e+7)), want: NewFloat(31415169.2).ToObject()}, + {args: wrapArgs(complex(math.Inf(-1), 1.2)), want: NewFloat(math.Inf(1)).ToObject()}, + {args: wrapArgs(complex(3.4, math.Inf(1))), want: NewFloat(math.Inf(1)).ToObject()}, + {args: wrapArgs(complex(math.Inf(1), math.Inf(-1))), want: NewFloat(math.Inf(1)).ToObject()}, + {args: wrapArgs(complex(math.Inf(1), math.NaN())), want: NewFloat(math.Inf(1)).ToObject()}, + {args: wrapArgs(complex(math.NaN(), math.Inf(1))), want: NewFloat(math.Inf(1)).ToObject()}, + {args: wrapArgs(complex(math.NaN(), 5.6)), want: NewFloat(math.NaN()).ToObject()}, + {args: wrapArgs(complex(7.8, math.NaN())), want: NewFloat(math.NaN()).ToObject()}, + } + for _, cas := range cases { + switch got, match := checkInvokeResult(wrapFuncForTest(complexAbs), cas.args, cas.want, cas.wantExc); match { + case checkInvokeResultReturnValueMismatch: + if got == nil || cas.want == nil || !got.isInstance(FloatType) || !cas.want.isInstance(FloatType) || + !floatsAreSame(toFloatUnsafe(got).Value(), toFloatUnsafe(cas.want).Value()) { + t.Errorf("complex.__abs__%v = %v, want %v", cas.args, got, cas.want) + } + } + } +} + +func TestComplexEq(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(complex(0, 0), 0), want: True.ToObject()}, + {args: wrapArgs(complex(1, 0), 0), want: False.ToObject()}, + {args: wrapArgs(complex(-12, 0), -12), want: True.ToObject()}, + {args: wrapArgs(complex(-12, 0), 1), want: False.ToObject()}, + {args: wrapArgs(complex(17.20, 0), 17.20), want: True.ToObject()}, + {args: wrapArgs(complex(1.2, 0), 17.20), want: False.ToObject()}, + {args: wrapArgs(complex(-4, 15), complex(-4, 15)), want: True.ToObject()}, + {args: wrapArgs(complex(-4, 15), complex(1, 2)), want: False.ToObject()}, + {args: wrapArgs(complex(math.Inf(1), 0), complex(math.Inf(1), 0)), want: True.ToObject()}, + {args: wrapArgs(complex(math.Inf(1), 0), complex(0, math.Inf(1))), want: False.ToObject()}, + {args: wrapArgs(complex(math.Inf(-1), 0), complex(math.Inf(-1), 0)), want: True.ToObject()}, + {args: wrapArgs(complex(math.Inf(-1), 0), complex(0, math.Inf(-1))), want: False.ToObject()}, + {args: wrapArgs(complex(math.Inf(1), math.Inf(1)), complex(math.Inf(1), math.Inf(1))), want: True.ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(complexEq), &cas); err != "" { + t.Error(err) + } + } +} + +// FIXME(corona10): Go 1.9 moved to C99-compatible complex division, which is also what CPython uses. +// Some tests may fail with Go versions older than 1.9, so we need to detect the Go version +// and adjust the expected values accordingly.
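The FIXME above calls for detecting the Go version before asserting division results. A minimal sketch of such a check (purely illustrative; goVersionAtLeast is a hypothetical helper and is not part of this diff):

package grumpy

import (
	"runtime"
	"strconv"
	"strings"
)

// goVersionAtLeast reports whether the running toolchain is at least
// go<major>.<minor>. Hypothetical helper sketched only to illustrate the
// FIXME above.
func goVersionAtLeast(major, minor int) bool {
	v := strings.TrimPrefix(runtime.Version(), "go") // e.g. "1.9.2"
	parts := strings.SplitN(v, ".", 3)
	if len(parts) < 2 {
		return false // e.g. "devel +abcdef"; treat as unknown
	}
	maj, errMaj := strconv.Atoi(parts[0])
	min, errMin := strconv.Atoi(parts[1])
	if errMaj != nil || errMin != nil {
		return false
	}
	return maj > major || (maj == major && min >= minor)
}

A test could then switch its expected constants (or skip the affected cases) when goVersionAtLeast(1, 9) is false.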
+ +func TestComplexBinaryOps(t *testing.T) { + cases := []struct { + fun func(f *Frame, v, w *Object) (*Object, *BaseException) + v, w *Object + want *Object + wantExc *BaseException + }{ + {Add, NewComplex(1 + 3i).ToObject(), NewInt(1).ToObject(), NewComplex(2 + 3i).ToObject(), nil}, + {Add, NewComplex(1 + 3i).ToObject(), NewFloat(-1).ToObject(), NewComplex(3i).ToObject(), nil}, + {Add, NewComplex(1 + 3i).ToObject(), NewInt(1).ToObject(), NewComplex(2 + 3i).ToObject(), nil}, + {Add, NewComplex(1 + 3i).ToObject(), NewComplex(-1 - 3i).ToObject(), NewComplex(0i).ToObject(), nil}, + {Add, NewFloat(math.Inf(1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(1), 3)).ToObject(), nil}, + {Add, NewFloat(math.Inf(-1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(-1), 3)).ToObject(), nil}, + {Add, NewFloat(math.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.NaN(), 3)).ToObject(), nil}, + {Add, NewComplex(cmplx.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(cmplx.NaN()).ToObject(), nil}, + {Add, NewFloat(math.Inf(-1)).ToObject(), NewComplex(complex(math.Inf(+1), 3)).ToObject(), NewComplex(complex(math.NaN(), 3)).ToObject(), nil}, + {Add, NewComplex(1 + 3i).ToObject(), None, nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for +: 'complex' and 'NoneType'")}, + {Add, None, NewComplex(1 + 3i).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for +: 'NoneType' and 'complex'")}, + {Add, NewInt(3).ToObject(), NewComplex(3i).ToObject(), NewComplex(3 + 3i).ToObject(), nil}, + {Add, NewLong(big.NewInt(9999999)).ToObject(), NewComplex(3i).ToObject(), NewComplex(9999999 + 3i).ToObject(), nil}, + {Add, NewFloat(3.5).ToObject(), NewComplex(3i).ToObject(), NewComplex(3.5 + 3i).ToObject(), nil}, + {Div, NewComplex(1 + 2i).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(1 + 0i).ToObject(), nil}, + {Div, NewComplex(3 + 4i).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(2.2 - 0.4i).ToObject(), nil}, + {Div, NewComplex(3.14 - 0.618i).ToObject(), NewComplex(-0.123e-4 + 0.151692i).ToObject(), NewComplex(-4.075723201992163 - 20.69950866627519i).ToObject(), nil}, + {Div, NewInt(3).ToObject(), NewComplex(3 - 4i).ToObject(), NewComplex(0.36 + 0.48i).ToObject(), nil}, + {Div, NewComplex(3 + 4i).ToObject(), NewInt(-5).ToObject(), NewComplex(-0.6 - 0.8i).ToObject(), nil}, + {Div, NewFloat(1.2).ToObject(), NewComplex(1 - 2i).ToObject(), NewComplex(0.24 + 0.48i).ToObject(), nil}, + {Div, NewComplex(1 + 2i).ToObject(), NewFloat(-3.4).ToObject(), NewComplex(-0.29411764705882354 - 0.5882352941176471i).ToObject(), nil}, + {Div, NewLong(big.NewInt(123)).ToObject(), NewComplex(3 + 4i).ToObject(), NewComplex(14.76 - 19.68i).ToObject(), nil}, + {Div, NewComplex(3 - 4i).ToObject(), NewLong(big.NewInt(-34)).ToObject(), NewComplex(-0.08823529411764706 + 0.11764705882352941i).ToObject(), nil}, + {Div, NewComplex(3 + 4i).ToObject(), NewComplex(complex(math.Inf(1), math.Inf(-1))).ToObject(), NewComplex(0i).ToObject(), nil}, + {Div, NewComplex(3 + 4i).ToObject(), NewComplex(complex(math.Inf(1), 2)).ToObject(), NewComplex(0i).ToObject(), nil}, + {Div, NewComplex(complex(math.Inf(1), math.Inf(1))).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.Inf(1), math.NaN())).ToObject(), nil}, + {Div, NewComplex(complex(math.Inf(1), 4)).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.Inf(1), math.Inf(-1))).ToObject(), nil}, + {Div, NewComplex(complex(3, math.Inf(1))).ToObject(), 
NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.Inf(1), math.Inf(1))).ToObject(), nil}, + {Div, NewComplex(complex(3, math.NaN())).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject(), nil}, + {Div, NewStr("foo").ToObject(), NewComplex(1 + 2i).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for /: 'str' and 'complex'")}, + {Div, NewComplex(3 + 4i).ToObject(), NewComplex(0 + 0i).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "complex division or modulo by zero")}, + {Div, NewComplex(complex(math.Inf(1), math.NaN())).ToObject(), NewComplex(0 + 0i).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "complex division or modulo by zero")}, + {Div, NewComplex(3 + 4i).ToObject(), NewLong(bigLongNumber).ToObject(), nil, mustCreateException(OverflowErrorType, "long int too large to convert to complex")}, + {FloorDiv, NewComplex(1 + 2i).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(1 + 0i).ToObject(), nil}, + {FloorDiv, NewComplex(3 + 4i).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(2 - 0i).ToObject(), nil}, + {FloorDiv, NewComplex(3.14 - 0.618i).ToObject(), NewComplex(-0.123e-4 + 0.151692i).ToObject(), NewComplex(-5 - 0i).ToObject(), nil}, + {FloorDiv, NewInt(3).ToObject(), NewComplex(3 - 4i).ToObject(), NewComplex(0i).ToObject(), nil}, + {FloorDiv, NewComplex(3 + 4i).ToObject(), NewInt(-5).ToObject(), NewComplex(-1 + 0i).ToObject(), nil}, + {FloorDiv, NewFloat(1.2).ToObject(), NewComplex(1 - 2i).ToObject(), NewComplex(0i).ToObject(), nil}, + {FloorDiv, NewComplex(1 + 2i).ToObject(), NewFloat(-3.4).ToObject(), NewComplex(-1 + 0i).ToObject(), nil}, + {FloorDiv, NewLong(big.NewInt(123)).ToObject(), NewComplex(3 + 4i).ToObject(), NewComplex(14 - 0i).ToObject(), nil}, + {FloorDiv, NewComplex(3 - 4i).ToObject(), NewLong(big.NewInt(-34)).ToObject(), NewComplex(-1 + 0i).ToObject(), nil}, + {FloorDiv, NewComplex(3 + 4i).ToObject(), NewComplex(complex(math.Inf(1), math.Inf(-1))).ToObject(), NewComplex(0i).ToObject(), nil}, + {FloorDiv, NewComplex(3 + 4i).ToObject(), NewComplex(complex(math.Inf(1), 2)).ToObject(), NewComplex(0i).ToObject(), nil}, + {FloorDiv, NewComplex(complex(math.Inf(1), math.Inf(1))).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.Inf(1), 0)).ToObject(), nil}, + {FloorDiv, NewComplex(complex(math.Inf(1), 4)).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.Inf(1), 0)).ToObject(), nil}, + {FloorDiv, NewComplex(complex(3, math.Inf(1))).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.Inf(1), 0)).ToObject(), nil}, + {FloorDiv, NewComplex(complex(3, math.NaN())).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.NaN(), 0)).ToObject(), nil}, + {FloorDiv, NewStr("foo").ToObject(), NewComplex(1 + 2i).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for //: 'str' and 'complex'")}, + {FloorDiv, NewComplex(3 + 4i).ToObject(), NewComplex(0 + 0i).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "complex division or modulo by zero")}, + {FloorDiv, NewComplex(complex(math.Inf(1), math.NaN())).ToObject(), NewComplex(0 + 0i).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "complex division or modulo by zero")}, + {FloorDiv, NewComplex(3 + 4i).ToObject(), NewLong(bigLongNumber).ToObject(), nil, mustCreateException(OverflowErrorType, "long int too large to convert to complex")}, + {Mod, NewComplex(3 + 4i).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(1 
+ 0i).ToObject(), nil}, + {Mod, NewComplex(1 + 2i).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(0i).ToObject(), nil}, + {Mod, NewComplex(3.14 - 0.618i).ToObject(), NewComplex(-0.123e-4 + 0.151692i).ToObject(), NewComplex(3.1399385 + 0.14045999999999992i).ToObject(), nil}, + {Mod, NewInt(3).ToObject(), NewComplex(3 - 4i).ToObject(), NewComplex(3 + 0i).ToObject(), nil}, + {Mod, NewComplex(3 + 4i).ToObject(), NewInt(-5).ToObject(), NewComplex(-2 + 4i).ToObject(), nil}, + {Mod, NewFloat(1.2).ToObject(), NewComplex(1 - 2i).ToObject(), NewComplex(1.2 + 0i).ToObject(), nil}, + {Mod, NewComplex(1 + 2i).ToObject(), NewFloat(-3.4).ToObject(), NewComplex(-2.4 + 2i).ToObject(), nil}, + {Mod, NewLong(big.NewInt(123)).ToObject(), NewComplex(3 + 4i).ToObject(), NewComplex(81 - 56i).ToObject(), nil}, + {Mod, NewComplex(3 - 4i).ToObject(), NewLong(big.NewInt(-34)).ToObject(), NewComplex(-31 - 4i).ToObject(), nil}, + {Mod, NewComplex(3 + 4i).ToObject(), NewComplex(complex(math.Inf(1), math.Inf(-1))).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject(), nil}, + {Mod, NewComplex(3 + 4i).ToObject(), NewComplex(complex(math.Inf(1), 2)).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject(), nil}, + {Mod, NewComplex(complex(math.Inf(1), math.Inf(1))).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject(), nil}, + {Mod, NewComplex(complex(math.Inf(1), 4)).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.NaN(), math.Inf(-1))).ToObject(), nil}, + {Mod, NewComplex(complex(3, math.Inf(1))).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.Inf(-1), math.NaN())).ToObject(), nil}, + {Mod, NewComplex(complex(3, math.NaN())).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject(), nil}, + {Mod, NewStr("foo").ToObject(), NewComplex(1 + 2i).ToObject(), nil, mustCreateException(TypeErrorType, "not all arguments converted during string formatting")}, + {Mod, NewComplex(3 + 4i).ToObject(), NewComplex(0 + 0i).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "complex division or modulo by zero")}, + {Mod, NewComplex(complex(math.Inf(1), math.NaN())).ToObject(), NewComplex(0 + 0i).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "complex division or modulo by zero")}, + {Mod, NewComplex(3 + 4i).ToObject(), NewLong(bigLongNumber).ToObject(), nil, mustCreateException(OverflowErrorType, "long int too large to convert to complex")}, + {Sub, NewComplex(1 + 3i).ToObject(), NewComplex(1 + 3i).ToObject(), NewComplex(0i).ToObject(), nil}, + {Sub, NewComplex(1 + 3i).ToObject(), NewComplex(3i).ToObject(), NewComplex(1).ToObject(), nil}, + {Sub, NewComplex(1 + 3i).ToObject(), NewFloat(1).ToObject(), NewComplex(3i).ToObject(), nil}, + {Sub, NewComplex(3i).ToObject(), NewFloat(1.2).ToObject(), NewComplex(-1.2 + 3i).ToObject(), nil}, + {Sub, NewComplex(1 + 3i).ToObject(), NewComplex(1 + 3i).ToObject(), NewComplex(0i).ToObject(), nil}, + {Sub, NewComplex(4 + 3i).ToObject(), NewInt(1).ToObject(), NewComplex(3 + 3i).ToObject(), nil}, + {Sub, NewComplex(4 + 3i).ToObject(), NewLong(big.NewInt(99994)).ToObject(), NewComplex(-99990 + 3i).ToObject(), nil}, + {Sub, NewFloat(math.Inf(1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(1), -3)).ToObject(), nil}, + {Sub, NewFloat(math.Inf(-1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(-1), -3)).ToObject(), nil}, + {Sub, NewComplex(1 + 3i).ToObject(), None, nil, 
mustCreateException(TypeErrorType, "unsupported operand type(s) for -: 'complex' and 'NoneType'")}, + {Sub, None, NewComplex(1 + 3i).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for -: 'NoneType' and 'complex'")}, + {Sub, NewFloat(math.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.NaN(), -3)).ToObject(), nil}, + {Sub, NewComplex(cmplx.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(cmplx.NaN()).ToObject(), nil}, + {Sub, NewFloat(math.Inf(-1)).ToObject(), NewComplex(complex(math.Inf(-1), 3)).ToObject(), NewComplex(complex(math.NaN(), -3)).ToObject(), nil}, + {Mul, NewComplex(1 + 3i).ToObject(), NewComplex(1 + 3i).ToObject(), NewComplex(-8 + 6i).ToObject(), nil}, + {Mul, NewComplex(1 + 3i).ToObject(), NewComplex(3i).ToObject(), NewComplex(-9 + 3i).ToObject(), nil}, + {Mul, NewComplex(1 + 3i).ToObject(), NewFloat(1).ToObject(), NewComplex(1 + 3i).ToObject(), nil}, + {Mul, NewComplex(3i).ToObject(), NewFloat(1.2).ToObject(), NewComplex(3.5999999999999996i).ToObject(), nil}, + {Mul, NewComplex(1 + 3i).ToObject(), NewComplex(1 + 3i).ToObject(), NewComplex(-8 + 6i).ToObject(), nil}, + {Mul, NewComplex(4 + 3i).ToObject(), NewInt(1).ToObject(), NewComplex(4 + 3i).ToObject(), nil}, + {Mul, NewComplex(4 + 3i).ToObject(), NewLong(big.NewInt(99994)).ToObject(), NewComplex(399976 + 299982i).ToObject(), nil}, + {Mul, NewFloat(math.Inf(1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.NaN(), math.Inf(1))).ToObject(), nil}, + {Mul, NewFloat(math.Inf(-1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.NaN(), math.Inf(-1))).ToObject(), nil}, + {Mul, NewComplex(1 + 3i).ToObject(), None, nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for *: 'complex' and 'NoneType'")}, + {Mul, None, NewComplex(1 + 3i).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for *: 'NoneType' and 'complex'")}, + {Mul, NewFloat(math.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject(), nil}, + {Mul, NewComplex(cmplx.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(cmplx.NaN()).ToObject(), nil}, + {Mul, NewFloat(math.Inf(-1)).ToObject(), NewComplex(complex(math.Inf(-1), 3)).ToObject(), NewComplex(complex(math.Inf(1), math.NaN())).ToObject(), nil}, + {Pow, NewComplex(0i).ToObject(), NewComplex(0i).ToObject(), NewComplex(1 + 0i).ToObject(), nil}, + {Pow, NewComplex(-1 + 0i).ToObject(), NewComplex(1i).ToObject(), NewComplex(0.04321391826377226 + 0i).ToObject(), nil}, + {Pow, NewComplex(1 + 2i).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(-0.22251715680177264 + 0.10070913113607538i).ToObject(), nil}, + {Pow, NewComplex(0i).ToObject(), NewComplex(-1 + 0i).ToObject(), NewComplex(complex(math.Inf(1), 0)).ToObject(), nil}, + {Pow, NewComplex(0i).ToObject(), NewComplex(-1 + 1i).ToObject(), NewComplex(complex(math.Inf(1), math.Inf(1))).ToObject(), nil}, + {Pow, NewComplex(complex(math.Inf(-1), 2)).ToObject(), NewComplex(1 + 2i).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject(), nil}, + {Pow, NewComplex(1 + 2i).ToObject(), NewComplex(complex(1, math.Inf(1))).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject(), nil}, + {Pow, NewComplex(complex(math.NaN(), 1)).ToObject(), NewComplex(3 + 4i).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject(), nil}, + {Pow, NewComplex(3 + 4i).ToObject(), NewInt(3).ToObject(), NewComplex(-117 + 44.00000000000003i).ToObject(), nil}, + {Pow, NewComplex(3 + 
4i).ToObject(), NewFloat(3.1415).ToObject(), NewComplex(-152.8892667678244 + 35.555335130496516i).ToObject(), nil}, + {Pow, NewComplex(3 + 4i).ToObject(), NewLong(big.NewInt(123)).ToObject(), NewComplex(5.393538720276193e+85 + 7.703512580443326e+85i).ToObject(), nil}, + {Pow, NewComplex(1 + 2i).ToObject(), NewStr("foo").ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for **: 'complex' and 'str'")}, + {Pow, NewStr("foo").ToObject(), NewComplex(1 + 2i).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for **: 'str' and 'complex'")}, + } + + for _, cas := range cases { + switch got, result := checkInvokeResult(wrapFuncForTest(cas.fun), []*Object{cas.v, cas.w}, cas.want, cas.wantExc); result { + case checkInvokeResultExceptionMismatch: + t.Errorf("%s(%v, %v) raised %v, want %v", getFuncName(cas.fun), cas.v, cas.w, got, cas.wantExc) + case checkInvokeResultReturnValueMismatch: + if got == nil || cas.want == nil || !got.isInstance(ComplexType) || !cas.want.isInstance(ComplexType) || + !complexesAreSame(toComplexUnsafe(got).Value(), toComplexUnsafe(cas.want).Value()) { + t.Errorf("%s(%v, %v) = %v, want %v", getFuncName(cas.fun), cas.v, cas.w, got, cas.want) + } + } + } +} + +func TestComplexCompareNotSupported(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(complex(1, 2), 1), wantExc: mustCreateException(TypeErrorType, "no ordering relation is defined for complex numbers")}, + {args: wrapArgs(complex(1, 2), 1.2), wantExc: mustCreateException(TypeErrorType, "no ordering relation is defined for complex numbers")}, + {args: wrapArgs(complex(1, 2), math.NaN()), wantExc: mustCreateException(TypeErrorType, "no ordering relation is defined for complex numbers")}, + {args: wrapArgs(complex(1, 2), math.Inf(-1)), wantExc: mustCreateException(TypeErrorType, "no ordering relation is defined for complex numbers")}, + {args: wrapArgs(complex(1, 2), "abc"), want: NotImplemented}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(complexCompareNotSupported), &cas); err != "" { + t.Error(err) + } + } +} + +func TestComplexDivMod(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs((1 + 2i), (1 + 2i)), want: NewTuple2(NewComplex(1+0i).ToObject(), NewComplex(0i).ToObject()).ToObject()}, + {args: wrapArgs((3 + 4i), (1 + 2i)), want: NewTuple2(NewComplex(2-0i).ToObject(), NewComplex(1+0i).ToObject()).ToObject()}, + {args: wrapArgs((3.14 - 0.618i), (-0.123e-4 + 0.151692i)), want: NewTuple2(NewComplex(-5-0i).ToObject(), NewComplex(3.1399385+0.14045999999999992i).ToObject()).ToObject()}, + {args: wrapArgs(3, (3 - 4i)), want: NewTuple2(NewComplex(0i).ToObject(), NewComplex(3+0i).ToObject()).ToObject()}, + {args: wrapArgs((3 + 4i), -5), want: NewTuple2(NewComplex(-1+0i).ToObject(), NewComplex(-2+4i).ToObject()).ToObject()}, + {args: wrapArgs(1.2, (1 - 2i)), want: NewTuple2(NewComplex(0i).ToObject(), NewComplex(1.2+0i).ToObject()).ToObject()}, + {args: wrapArgs((1 + 2i), -3.4), want: NewTuple2(NewComplex(-1+0i).ToObject(), NewComplex(-2.4+2i).ToObject()).ToObject()}, + {args: wrapArgs(big.NewInt(123), (3 + 4i)), want: NewTuple2(NewComplex(14-0i).ToObject(), NewComplex(81-56i).ToObject()).ToObject()}, + {args: wrapArgs((3 - 4i), big.NewInt(-34)), want: NewTuple2(NewComplex(-1+0i).ToObject(), NewComplex(-31-4i).ToObject()).ToObject()}, + {args: wrapArgs((3 + 4i), complex(math.Inf(1), math.Inf(-1))), want: NewTuple2(NewComplex(0i).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject()).ToObject()}, 
+ {args: wrapArgs((3 + 4i), complex(math.Inf(1), 2)), want: NewTuple2(NewComplex(0i).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject()).ToObject()}, + {args: wrapArgs(complex(math.Inf(1), math.Inf(1)), (1 + 2i)), want: NewTuple2(NewComplex(complex(math.Inf(1), 0)).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject()).ToObject()}, + {args: wrapArgs(complex(math.Inf(1), 4), (1 + 2i)), want: NewTuple2(NewComplex(complex(math.Inf(1), 0)).ToObject(), NewComplex(complex(math.NaN(), math.Inf(-1))).ToObject()).ToObject()}, + {args: wrapArgs(complex(3, math.Inf(1)), (1 + 2i)), want: NewTuple2(NewComplex(complex(math.Inf(1), 0)).ToObject(), NewComplex(complex(math.Inf(-1), math.NaN())).ToObject()).ToObject()}, + {args: wrapArgs(complex(3, math.NaN()), (1 + 2i)), want: NewTuple2(NewComplex(complex(math.NaN(), 0)).ToObject(), NewComplex(complex(math.NaN(), math.NaN())).ToObject()).ToObject()}, + {args: wrapArgs("foo", (1 + 2i)), wantExc: mustCreateException(TypeErrorType, "unsupported operand type(s) for divmod(): 'str' and 'complex'")}, + {args: wrapArgs((3 + 4i), (0 + 0i)), wantExc: mustCreateException(ZeroDivisionErrorType, "complex division or modulo by zero")}, + {args: wrapArgs(complex(math.Inf(1), math.NaN()), (0 + 0i)), wantExc: mustCreateException(ZeroDivisionErrorType, "complex division or modulo by zero")}, + {args: wrapArgs((3 + 4i), bigLongNumber), wantExc: mustCreateException(OverflowErrorType, "long int too large to convert to complex")}, + } + for _, cas := range cases { + switch got, result := checkInvokeResult(wrapFuncForTest(DivMod), cas.args, cas.want, cas.wantExc); result { + case checkInvokeResultExceptionMismatch: + t.Errorf("complex.__divmod__%v raised %v, want %v", cas.args, got, cas.wantExc) + case checkInvokeResultReturnValueMismatch: + // Handle NaN specially, since NaN != NaN. 
+ if got == nil || cas.want == nil || !got.isInstance(TupleType) || !cas.want.isInstance(TupleType) || !tupleComplexesAreSame(got, cas.want) { + t.Errorf("complex.__divmod__%v = %v, want %v", cas.args, got, cas.want) + } + } + } +} + +func TestComplexNE(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(complex(0, 0), 0), want: False.ToObject()}, + {args: wrapArgs(complex(1, 0), 0), want: True.ToObject()}, + {args: wrapArgs(complex(-12, 0), -12), want: False.ToObject()}, + {args: wrapArgs(complex(-12, 0), 1), want: True.ToObject()}, + {args: wrapArgs(complex(17.20, 0), 17.20), want: False.ToObject()}, + {args: wrapArgs(complex(1.2, 0), 17.20), want: True.ToObject()}, + {args: wrapArgs(complex(-4, 15), complex(-4, 15)), want: False.ToObject()}, + {args: wrapArgs(complex(-4, 15), complex(1, 2)), want: True.ToObject()}, + {args: wrapArgs(complex(math.Inf(1), 0), complex(math.Inf(1), 0)), want: False.ToObject()}, + {args: wrapArgs(complex(math.Inf(1), 0), complex(0, math.Inf(1))), want: True.ToObject()}, + {args: wrapArgs(complex(math.Inf(-1), 0), complex(math.Inf(-1), 0)), want: False.ToObject()}, + {args: wrapArgs(complex(math.Inf(-1), 0), complex(0, math.Inf(-1))), want: True.ToObject()}, + {args: wrapArgs(complex(math.Inf(1), math.Inf(1)), complex(math.Inf(1), math.Inf(1))), want: False.ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(complexNE), &cas); err != "" { + t.Error(err) + } + } +} + +func TestComplexNew(t *testing.T) { + complexNew := mustNotRaise(GetAttr(NewRootFrame(), ComplexType.ToObject(), NewStr("__new__"), nil)) + goodSlotType := newTestClass("GoodSlot", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__complex__": newBuiltinFunction("__complex__", func(_ *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return NewComplex(complex(1, 2)).ToObject(), nil + }).ToObject(), + })) + badSlotType := newTestClass("BadSlot", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__complex__": newBuiltinFunction("__complex__", func(_ *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return newObject(ObjectType), nil + }).ToObject(), + })) + strictEqType := newTestClassStrictEq("StrictEq", ComplexType) + newStrictEq := func(v complex128) *Object { + f := Complex{Object: Object{typ: strictEqType}, value: v} + return f.ToObject() + } + subType := newTestClass("SubType", []*Type{ComplexType}, newStringDict(map[string]*Object{})) + subTypeObject := (&Complex{Object: Object{typ: subType}, value: 3.14}).ToObject() + slotSubTypeType := newTestClass("SlotSubType", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__complex__": newBuiltinFunction("__complex__", func(_ *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return subTypeObject, nil + }).ToObject(), + })) + cases := []invokeTestCase{ + {args: wrapArgs(ComplexType), want: NewComplex(0).ToObject()}, + {args: wrapArgs(ComplexType, 56), want: NewComplex(complex(56, 0)).ToObject()}, + {args: wrapArgs(ComplexType, -12), want: NewComplex(complex(-12, 0)).ToObject()}, + {args: wrapArgs(ComplexType, 3.14), want: NewComplex(complex(3.14, 0)).ToObject()}, + {args: wrapArgs(ComplexType, -703.4), want: NewComplex(complex(-703.4, 0)).ToObject()}, + {args: wrapArgs(ComplexType, math.NaN()), want: NewComplex(complex(math.NaN(), 0)).ToObject()}, + {args: wrapArgs(ComplexType, math.Inf(1)), want: NewComplex(complex(math.Inf(1), 0)).ToObject()}, + {args: wrapArgs(ComplexType, math.Inf(-1)), want: NewComplex(complex(math.Inf(-1), 0)).ToObject()}, + {args: 
wrapArgs(ComplexType, biggestFloat), want: NewComplex(complex(math.MaxFloat64, 0)).ToObject()}, + {args: wrapArgs(ComplexType, new(big.Int).Neg(biggestFloat)), want: NewComplex(complex(-math.MaxFloat64, 0)).ToObject()}, + {args: wrapArgs(ComplexType, new(big.Int).Sub(big.NewInt(-1), biggestFloat)), wantExc: mustCreateException(OverflowErrorType, "long int too large to convert to float")}, + {args: wrapArgs(ComplexType, new(big.Int).Add(biggestFloat, big.NewInt(1))), wantExc: mustCreateException(OverflowErrorType, "long int too large to convert to float")}, + {args: wrapArgs(ComplexType, bigLongNumber), wantExc: mustCreateException(OverflowErrorType, "long int too large to convert to float")}, + {args: wrapArgs(ComplexType, complex(1, 2)), want: NewComplex(complex(1, 2)).ToObject()}, + {args: wrapArgs(ComplexType, complex(-0.0001e-1, 3.14151692)), want: NewComplex(complex(-0.00001, 3.14151692)).ToObject()}, + {args: wrapArgs(ComplexType, "23"), want: NewComplex(complex(23, 0)).ToObject()}, + {args: wrapArgs(ComplexType, "-516"), want: NewComplex(complex(-516, 0)).ToObject()}, + {args: wrapArgs(ComplexType, "1.003e4"), want: NewComplex(complex(10030, 0)).ToObject()}, + {args: wrapArgs(ComplexType, "151.7"), want: NewComplex(complex(151.7, 0)).ToObject()}, + {args: wrapArgs(ComplexType, "-74.02"), want: NewComplex(complex(-74.02, 0)).ToObject()}, + {args: wrapArgs(ComplexType, "+38.29"), want: NewComplex(complex(38.29, 0)).ToObject()}, + {args: wrapArgs(ComplexType, "8j"), want: NewComplex(complex(0, 8)).ToObject()}, + {args: wrapArgs(ComplexType, "-17j"), want: NewComplex(complex(0, -17)).ToObject()}, + {args: wrapArgs(ComplexType, "7.3j"), want: NewComplex(complex(0, 7.3)).ToObject()}, + {args: wrapArgs(ComplexType, "-4.786j"), want: NewComplex(complex(0, -4.786)).ToObject()}, + {args: wrapArgs(ComplexType, "+17.59123j"), want: NewComplex(complex(0, 17.59123)).ToObject()}, + {args: wrapArgs(ComplexType, "-3.0007e3j"), want: NewComplex(complex(0, -3000.7)).ToObject()}, + {args: wrapArgs(ComplexType, "1+2j"), want: NewComplex(complex(1, 2)).ToObject()}, + {args: wrapArgs(ComplexType, "3.1415-23j"), want: NewComplex(complex(3.1415, -23)).ToObject()}, + {args: wrapArgs(ComplexType, "-23+3.1415j"), want: NewComplex(complex(-23, 3.1415)).ToObject()}, + {args: wrapArgs(ComplexType, "+451.2192+384.27j"), want: NewComplex(complex(451.2192, 384.27)).ToObject()}, + {args: wrapArgs(ComplexType, "-38.378-283.28j"), want: NewComplex(complex(-38.378, -283.28)).ToObject()}, + {args: wrapArgs(ComplexType, "1.76123e2+0.000007e6j"), want: NewComplex(complex(176.123, 7)).ToObject()}, + {args: wrapArgs(ComplexType, "-nan+nanj"), want: NewComplex(complex(math.NaN(), math.NaN())).ToObject()}, + {args: wrapArgs(ComplexType, "inf-infj"), want: NewComplex(complex(math.Inf(1), math.Inf(-1))).ToObject()}, + {args: wrapArgs(ComplexType, 1, 2), want: NewComplex(complex(1, 2)).ToObject()}, + {args: wrapArgs(ComplexType, 7, 23.45), want: NewComplex(complex(7, 23.45)).ToObject()}, + {args: wrapArgs(ComplexType, 28.2537, -19), want: NewComplex(complex(28.2537, -19)).ToObject()}, + {args: wrapArgs(ComplexType, -3.14, -0.685), want: NewComplex(complex(-3.14, -0.685)).ToObject()}, + {args: wrapArgs(ComplexType, -47.234e+2, 2.374e+3), want: NewComplex(complex(-4723.4, 2374)).ToObject()}, + {args: wrapArgs(ComplexType, -4.5, new(big.Int).Neg(biggestFloat)), want: NewComplex(complex(-4.5, -math.MaxFloat64)).ToObject()}, + {args: wrapArgs(ComplexType, biggestFloat, biggestFloat), want: NewComplex(complex(math.MaxFloat64, 
math.MaxFloat64)).ToObject()}, + {args: wrapArgs(ComplexType, 5, math.NaN()), want: NewComplex(complex(5, math.NaN())).ToObject()}, + {args: wrapArgs(ComplexType, math.Inf(-1), -95), want: NewComplex(complex(math.Inf(-1), -95)).ToObject()}, + {args: wrapArgs(ComplexType, math.NaN(), math.NaN()), want: NewComplex(complex(math.NaN(), math.NaN())).ToObject()}, + {args: wrapArgs(ComplexType, math.Inf(1), math.Inf(-1)), want: NewComplex(complex(math.Inf(1), math.Inf(-1))).ToObject()}, + {args: wrapArgs(ComplexType, complex(-48.8, 0.7395), 5.448), want: NewComplex(complex(-48.8, 6.1875)).ToObject()}, + {args: wrapArgs(ComplexType, -3.14, complex(-4.5, -0.618)), want: NewComplex(complex(-2.5220000000000002, -4.5)).ToObject()}, + {args: wrapArgs(ComplexType, complex(1, 2), complex(3, 4)), want: NewComplex(complex(-3, 5)).ToObject()}, + {args: wrapArgs(ComplexType, complex(-2.47, 0.205e+2), complex(3.1, -0.4)), want: NewComplex(complex(-2.0700000000000003, 23.6)).ToObject()}, + {args: wrapArgs(ComplexType, "bar", 1.2), wantExc: mustCreateException(TypeErrorType, "complex() can't take second arg if first is a string")}, + {args: wrapArgs(ComplexType, "bar", None), wantExc: mustCreateException(TypeErrorType, "complex() can't take second arg if first is a string")}, + {args: wrapArgs(ComplexType, 1.2, "baz"), wantExc: mustCreateException(TypeErrorType, "complex() second arg can't be a string")}, + {args: wrapArgs(ComplexType, None, "baz"), wantExc: mustCreateException(TypeErrorType, "complex() second arg can't be a string")}, + {args: wrapArgs(ComplexType, newObject(goodSlotType)), want: NewComplex(complex(1, 2)).ToObject()}, + {args: wrapArgs(ComplexType, newObject(badSlotType)), wantExc: mustCreateException(TypeErrorType, "__complex__ returned non-complex (type object)")}, + {args: wrapArgs(ComplexType, newObject(slotSubTypeType)), want: subTypeObject}, + {args: wrapArgs(strictEqType, 3.14), want: newStrictEq(3.14)}, + {args: wrapArgs(strictEqType, newObject(goodSlotType)), want: newStrictEq(complex(1, 2))}, + {args: wrapArgs(strictEqType, newObject(badSlotType)), wantExc: mustCreateException(TypeErrorType, "__complex__ returned non-complex (type object)")}, + {args: wrapArgs(), wantExc: mustCreateException(TypeErrorType, "'__new__' requires 1 arguments")}, + {args: wrapArgs(FloatType), wantExc: mustCreateException(TypeErrorType, "complex.__new__(float): float is not a subtype of complex")}, + {args: wrapArgs(ComplexType, None), wantExc: mustCreateException(TypeErrorType, "complex() argument must be a string or a number")}, + {args: wrapArgs(ComplexType, "foo"), wantExc: mustCreateException(ValueErrorType, "complex() arg is a malformed string")}, + {args: wrapArgs(ComplexType, 123, None, None), wantExc: mustCreateException(TypeErrorType, "'__new__' of 'complex' requires at most 2 arguments")}, + } + for _, cas := range cases { + switch got, match := checkInvokeResult(complexNew, cas.args, cas.want, cas.wantExc); match { + case checkInvokeResultExceptionMismatch: + t.Errorf("complex.__new__%v raised %v, want %v", cas.args, got, cas.wantExc) + case checkInvokeResultReturnValueMismatch: + // Handle NaN specially, since NaN != NaN. 
+ if got == nil || cas.want == nil || !got.isInstance(ComplexType) || !cas.want.isInstance(ComplexType) || + !cmplx.IsNaN(toComplexUnsafe(got).Value()) || !cmplx.IsNaN(toComplexUnsafe(cas.want).Value()) { + t.Errorf("complex.__new__%v = %v, want %v", cas.args, got, cas.want) + } + } + } +} + +func TestComplexNonZero(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(complex(0, 0)), want: False.ToObject()}, + {args: wrapArgs(complex(.0, .0)), want: False.ToObject()}, + {args: wrapArgs(complex(0.0, 0.1)), want: True.ToObject()}, + {args: wrapArgs(complex(1, 0)), want: True.ToObject()}, + {args: wrapArgs(complex(3.14, -0.001e+5)), want: True.ToObject()}, + {args: wrapArgs(complex(math.NaN(), math.NaN())), want: True.ToObject()}, + {args: wrapArgs(complex(math.Inf(-1), math.Inf(1))), want: True.ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(complexNonZero), &cas); err != "" { + t.Error(err) + } + } +} + +func TestComplexPos(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(complex(0, 0)), want: NewComplex(complex(0, 0)).ToObject()}, + {args: wrapArgs(complex(42, -0.1)), want: NewComplex(complex(42, -0.1)).ToObject()}, + {args: wrapArgs(complex(-1.2, 375E+2)), want: NewComplex(complex(-1.2, 37500)).ToObject()}, + {args: wrapArgs(complex(5, math.NaN())), want: NewComplex(complex(5, math.NaN())).ToObject()}, + {args: wrapArgs(complex(math.Inf(1), 0.618)), want: NewComplex(complex(math.Inf(1), 0.618)).ToObject()}, + } + for _, cas := range cases { + switch got, match := checkInvokeResult(wrapFuncForTest(complexPos), cas.args, cas.want, cas.wantExc); match { + case checkInvokeResultReturnValueMismatch: + if got == nil || cas.want == nil || !got.isInstance(ComplexType) || !cas.want.isInstance(ComplexType) || + !complexesAreSame(toComplexUnsafe(got).Value(), toComplexUnsafe(cas.want).Value()) { + t.Errorf("complex.__pos__%v = %v, want %v", cas.args, got, cas.want) + } + } + } +} + +func TestComplexRepr(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(complex(0.0, 0.0)), want: NewStr("0j").ToObject()}, + {args: wrapArgs(complex(0.0, 1.0)), want: NewStr("1j").ToObject()}, + {args: wrapArgs(complex(1.0, 2.0)), want: NewStr("(1+2j)").ToObject()}, + {args: wrapArgs(complex(3.1, -4.2)), want: NewStr("(3.1-4.2j)").ToObject()}, + {args: wrapArgs(complex(math.NaN(), math.NaN())), want: NewStr("(nan+nanj)").ToObject()}, + {args: wrapArgs(complex(math.Inf(-1), math.Inf(1))), want: NewStr("(-inf+infj)").ToObject()}, + {args: wrapArgs(complex(math.Inf(1), math.Inf(-1))), want: NewStr("(inf-infj)").ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(Repr), &cas); err != "" { + t.Error(err) + } + } +} + +func TestParseComplex(t *testing.T) { + var ErrSyntax = errors.New("invalid syntax") + cases := []struct { + s string + want complex128 + err error + }{ + {"5", complex(5, 0), nil}, + {"-3.14", complex(-3.14, 0), nil}, + {"1.8456e3", complex(1845.6, 0), nil}, + {"23j", complex(0, 23), nil}, + {"7j", complex(0, 7), nil}, + {"-365.12j", complex(0, -365.12), nil}, + {"1+2j", complex(1, 2), nil}, + {"-.3+.7j", complex(-0.3, 0.7), nil}, + {"-1.3+2.7j", complex(-1.3, 2.7), nil}, + {"48.39-20.3j", complex(48.39, -20.3), nil}, + {"-1.23e2-30.303j", complex(-123, -30.303), nil}, + {"-1.23e2-45.678e1j", complex(-123, -456.78), nil}, + {"nan+nanj", complex(math.NaN(), math.NaN()), nil}, + {"nan-nanj", complex(math.NaN(), math.NaN()), nil}, + {"-nan-nanj", complex(math.NaN(), math.NaN()), nil}, + 
{"inf+infj", complex(math.Inf(1), math.Inf(1)), nil}, + {"inf-infj", complex(math.Inf(1), math.Inf(-1)), nil}, + {"-inf-infj", complex(math.Inf(-1), math.Inf(-1)), nil}, + {"infINIty+infinityj", complex(math.Inf(1), math.Inf(1)), nil}, + {"3.4+j", complex(3.4, 1), nil}, + {"21.98-j", complex(21.98, -1), nil}, + {"+j", complex(0, 1), nil}, + {"-j", complex(0, -1), nil}, + {"j", complex(0, 1), nil}, + {"(2.1-3.4j)", complex(2.1, -3.4), nil}, + {" (2.1-3.4j) ", complex(2.1, -3.4), nil}, + {" ( 2.1-3.4j ) ", complex(2.1, -3.4), nil}, + {" \t \n \r ( \t \n \r 2.1-3.4j \t \n \r ) \t \n \r ", complex(2.1, -3.4), nil}, + {" 3.14-15.16j ", complex(3.14, -15.16), nil}, + {"(2.1-3.4j", complex(0, 0), ErrSyntax}, + {"((2.1-3.4j))", complex(0, 0), ErrSyntax}, + {"3.14 -15.16j", complex(0, 0), ErrSyntax}, + {"3.14- 15.16j", complex(0, 0), ErrSyntax}, + {"3.14-15.16 j", complex(0, 0), ErrSyntax}, + {"3.14 - 15.16 j", complex(0, 0), ErrSyntax}, + {"foo", complex(0, 0), ErrSyntax}, + {"foo+bar", complex(0, 0), ErrSyntax}, + } + for _, cas := range cases { + if got, _ := parseComplex(cas.s); !complexesAreSame(got, cas.want) { + t.Errorf("parseComplex(%q) = %g, want %g", cas.s, got, cas.want) + } + } +} + +func TestComplexHash(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(complex(0.0, 0.0)), want: NewInt(0).ToObject()}, + {args: wrapArgs(complex(0.0, 1.0)), want: NewInt(1000003).ToObject()}, + {args: wrapArgs(complex(1.0, 0.0)), want: NewInt(1).ToObject()}, + {args: wrapArgs(complex(3.1, -4.2)), want: NewInt(-1556830019620134).ToObject()}, + {args: wrapArgs(complex(3.1, 4.2)), want: NewInt(1557030815934348).ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(complexHash), &cas); err != "" { + t.Error(err) + } + } +} + +func floatsAreSame(a, b float64) bool { + return a == b || (math.IsNaN(a) && math.IsNaN(b)) +} + +func complexesAreSame(a, b complex128) bool { + return floatsAreSame(real(a), real(b)) && floatsAreSame(imag(a), imag(b)) +} + +func tupleComplexesAreSame(got, want *Object) bool { + if toTupleUnsafe(got).Len() != toTupleUnsafe(want).Len() { + return false + } + for i := 0; i < toTupleUnsafe(got).Len(); i++ { + if !complexesAreSame(toComplexUnsafe(toTupleUnsafe(got).GetItem(i)).Value(), toComplexUnsafe(toTupleUnsafe(want).GetItem(i)).Value()) { + return false + } + } + return true +} diff --git a/runtime/core.go b/runtime/core.go index 2c814f2a..4cc20f66 100644 --- a/runtime/core.go +++ b/runtime/core.go @@ -16,13 +16,17 @@ package grumpy import ( "fmt" - "io" "log" - "os" "reflect" + "sync/atomic" ) -var logFatal = func(msg string) { log.Fatal(msg) } +var ( + logFatal = func(msg string) { log.Fatal(msg) } + // ThreadCount is the number of goroutines started with StartThread that + // have not yet joined. + ThreadCount int64 +) // Abs returns the result of o.__abs__ and is equivalent to the Python // expression "abs(o)". @@ -161,6 +165,12 @@ func Div(f *Frame, v, w *Object) (*Object, *BaseException) { return binaryOp(f, v, w, v.typ.slots.Div, v.typ.slots.RDiv, w.typ.slots.RDiv, "/") } +// DivMod returns the result (quotient and remainder tuple) of dividing v by w +// according to the __divmod/rdivmod__ operator. +func DivMod(f *Frame, v, w *Object) (*Object, *BaseException) { + return binaryOp(f, v, w, v.typ.slots.DivMod, v.typ.slots.RDivMod, w.typ.slots.RDivMod, "divmod()") +} + // Eq returns the equality of v and w according to the __eq__ operator. 
func Eq(f *Frame, v, w *Object) (*Object, *BaseException) { r, raised := compareRich(f, compareOpEq, v, w) @@ -173,17 +183,39 @@ func Eq(f *Frame, v, w *Object) (*Object, *BaseException) { return GetBool(compareDefault(f, v, w) == 0).ToObject(), nil } -// FormatException returns a single-line exception string for the given -// exception object, e.g. "NameError: name 'x' is not defined\n". -func FormatException(f *Frame, e *BaseException) (string, *BaseException) { - s, raised := ToStr(f, e.ToObject()) +// FloorDiv returns the result of dividing v by w according to the __floordiv/rfloordiv__ operator. +func FloorDiv(f *Frame, v, w *Object) (*Object, *BaseException) { + return binaryOp(f, v, w, v.typ.slots.FloorDiv, v.typ.slots.RFloorDiv, w.typ.slots.RFloorDiv, "//") +} + +// FormatExc calls traceback.format_exc, falling back to the single line +// exception message if that fails, e.g. "NameError: name 'x' is not defined\n". +func FormatExc(f *Frame) (s string) { + exc, tb := f.ExcInfo() + defer func() { + if s == "" { + strResult, raised := ToStr(f, exc.ToObject()) + if raised == nil && strResult.Value() != "" { + s = fmt.Sprintf("%s: %s\n", exc.typ.Name(), strResult.Value()) + } else { + s = exc.typ.Name() + "\n" + } + } + f.RestoreExc(exc, tb) + }() + tbMod, raised := SysModules.GetItemString(f, "traceback") + if raised != nil || tbMod == nil { + return + } + formatExc, raised := GetAttr(f, tbMod, NewStr("format_exc"), nil) if raised != nil { - return "", raised + return } - if len(s.Value()) == 0 { - return e.typ.Name() + "\n", nil + result, raised := formatExc.Call(f, nil, nil) + if raised != nil || !result.isInstance(StrType) { + return } - return fmt.Sprintf("%s: %s\n", e.typ.Name(), s.Value()), nil + return toStrUnsafe(result).Value() } // GE returns the result of operation v >= w. @@ -253,6 +285,23 @@ func Hash(f *Frame, o *Object) (*Int, *BaseException) { return toIntUnsafe(h), nil } +// Hex returns the result of o.__hex__ if defined. +func Hex(f *Frame, o *Object) (*Object, *BaseException) { + hex := o.typ.slots.Hex + if hex == nil { + raised := f.RaiseType(TypeErrorType, "hex() argument can't be converted to hex") + return nil, raised + } + h, raised := hex.Fn(f, o) + if raised != nil { + return nil, raised + } + if !h.isInstance(StrType) { + return nil, f.RaiseType(TypeErrorType, fmt.Sprintf("__hex__ returned non-string (type %s)", h.typ.name)) + } + return h, nil +} + // IAdd returns the result of v.__iadd__ if defined, otherwise falls back to // Add. func IAdd(f *Frame, v, w *Object) (*Object, *BaseException) { @@ -271,6 +320,18 @@ func IDiv(f *Frame, v, w *Object) (*Object, *BaseException) { return inplaceOp(f, v, w, v.typ.slots.IDiv, Div) } +// IFloorDiv returns the result of v.__ifloordiv__ if defined, otherwise falls back to +// floordiv. +func IFloorDiv(f *Frame, v, w *Object) (*Object, *BaseException) { + return inplaceOp(f, v, w, v.typ.slots.IFloorDiv, FloorDiv) +} + +// ILShift returns the result of v.__ilshift__ if defined, otherwise falls back +// to lshift. +func ILShift(f *Frame, v, w *Object) (*Object, *BaseException) { + return inplaceOp(f, v, w, v.typ.slots.ILShift, LShift) +} + // IMod returns the result of v.__imod__ if defined, otherwise falls back to // mod. func IMod(f *Frame, v, w *Object) (*Object, *BaseException) { @@ -298,6 +359,17 @@ func IOr(f *Frame, v, w *Object) (*Object, *BaseException) { return inplaceOp(f, v, w, v.typ.slots.IOr, Or) } +// IPow returns the result of v.__ipow__ if defined, otherwise falls back to Pow.
+func IPow(f *Frame, v, w *Object) (*Object, *BaseException) { + return inplaceOp(f, v, w, v.typ.slots.IPow, Pow) +} + +// IRShift returns the result of v.__irshift__ if defined, otherwise falls back +// to rshift. +func IRShift(f *Frame, v, w *Object) (*Object, *BaseException) { + return inplaceOp(f, v, w, v.typ.slots.IRShift, RShift) +} + // IsInstance returns true if the type o is an instance of classinfo, or an // instance of an element in classinfo (if classinfo is a tuple). It returns // false otherwise. The argument classinfo must be a type or a tuple whose @@ -596,6 +668,33 @@ func Next(f *Frame, iter *Object) (*Object, *BaseException) { return next.Fn(f, iter) } +// Oct returns the result of o.__oct__ if defined. +func Oct(f *Frame, o *Object) (*Object, *BaseException) { + oct := o.typ.slots.Oct + if oct == nil { + raised := f.RaiseType(TypeErrorType, "oct() argument can't be converted to oct") + return nil, raised + } + o, raised := oct.Fn(f, o) + if raised != nil { + return nil, raised + } + if !o.isInstance(StrType) { + return nil, f.RaiseType(TypeErrorType, fmt.Sprintf("__oct__ returned non-string (type %s)", o.typ.name)) + } + return o, nil +} + +// Pos returns the result of o.__pos__ and is equivalent to the Python +// expression "+o". +func Pos(f *Frame, o *Object) (*Object, *BaseException) { + pos := o.typ.slots.Pos + if pos == nil { + return nil, f.RaiseType(TypeErrorType, fmt.Sprintf("bad operand type for unary +: '%s'", o.typ.Name())) + } + return pos.Fn(f, o) +} + // Print implements the Python print statement. It calls str() on the given args // and outputs the results to stdout separated by spaces. Similar to the Python // print statement. @@ -607,7 +706,7 @@ func Print(f *Frame, args Args, nl bool) *BaseException { } else if len(args) > 0 { end = " " } - return pyPrint(f, args, " ", end, os.Stdout) + return pyPrint(f, args, " ", end, Stdout) } // Repr returns a string containing a printable representation of o. This is @@ -704,14 +803,12 @@ func SetItem(f *Frame, o, key, value *Object) *BaseException { // StartThread runs callable in a new goroutine. func StartThread(callable *Object) { go func() { + atomic.AddInt64(&ThreadCount, 1) + defer atomic.AddInt64(&ThreadCount, -1) f := NewRootFrame() _, raised := callable.Call(f, nil, nil) if raised != nil { - s, raised := FormatException(f, raised) - if raised != nil { - s = raised.String() - } - fmt.Fprintf(os.Stderr, s) + Stderr.writeString(FormatExc(f)) } }() } @@ -1182,17 +1279,30 @@ func hashNotImplemented(f *Frame, o *Object) (*Object, *BaseException) { } // pyPrint encapsulates the logic of the Python print function. 
-func pyPrint(f *Frame, args Args, sep, end string, file io.Writer) *BaseException { +func pyPrint(f *Frame, args Args, sep, end string, file *File) *BaseException { for i, arg := range args { if i > 0 { - fmt.Fprint(file, sep) + err := file.writeString(sep) + if err != nil { + return f.RaiseType(IOErrorType, err.Error()) + } } + s, raised := ToStr(f, arg) if raised != nil { return raised } - fmt.Fprint(file, s.Value()) + + err := file.writeString(s.Value()) + if err != nil { + return f.RaiseType(IOErrorType, err.Error()) + } } - fmt.Fprint(file, end) + + err := file.writeString(end) + if err != nil { + return f.RaiseType(IOErrorType, err.Error()) + } + return nil } diff --git a/runtime/core_test.go b/runtime/core_test.go index 4c0fdb70..5a451f5d 100644 --- a/runtime/core_test.go +++ b/runtime/core_test.go @@ -15,7 +15,6 @@ package grumpy import ( - "bytes" "fmt" "math/big" "reflect" @@ -84,6 +83,9 @@ func TestBinaryOps(t *testing.T) { "__idiv__": newBuiltinFunction("__idiv__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return args[1], nil }).ToObject(), + "__ilshift__": newBuiltinFunction("__ilshift__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + return args[1], nil + }).ToObject(), "__imod__": newBuiltinFunction("__imod__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return args[1], nil }).ToObject(), @@ -93,6 +95,9 @@ func TestBinaryOps(t *testing.T) { "__ior__": newBuiltinFunction("__ior__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return args[1], nil }).ToObject(), + "__irshift__": newBuiltinFunction("__irshift__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + return args[1], nil + }).ToObject(), "__isub__": newBuiltinFunction("__isub__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return args[1], nil }).ToObject(), @@ -122,6 +127,7 @@ func TestBinaryOps(t *testing.T) { {IAnd, newObject(ObjectType), newObject(fooType), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for &: 'object' and 'Foo'")}, {IDiv, NewInt(123).ToObject(), newObject(bazType), NewStr("123").ToObject(), nil}, {IDiv, newObject(inplaceType), NewInt(42).ToObject(), NewInt(42).ToObject(), nil}, + {ILShift, newObject(inplaceType), NewInt(123).ToObject(), NewInt(123).ToObject(), nil}, {IMod, NewInt(24).ToObject(), NewInt(6).ToObject(), NewInt(0).ToObject(), nil}, {IMod, newObject(inplaceType), NewFloat(3.14).ToObject(), NewFloat(3.14).ToObject(), nil}, {IMul, NewStr("foo").ToObject(), NewInt(3).ToObject(), NewStr("foofoofoo").ToObject(), nil}, @@ -130,6 +136,7 @@ func TestBinaryOps(t *testing.T) { {IOr, newObject(inplaceType), NewInt(42).ToObject(), NewInt(42).ToObject(), nil}, {IOr, NewInt(9).ToObject(), NewInt(12).ToObject(), NewInt(13).ToObject(), nil}, {IOr, newObject(ObjectType), newObject(fooType), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for |: 'object' and 'Foo'")}, + {IRShift, newObject(inplaceType), NewInt(123).ToObject(), NewInt(123).ToObject(), nil}, {ISub, NewInt(3).ToObject(), NewInt(-3).ToObject(), NewInt(6).ToObject(), nil}, {ISub, newObject(inplaceType), None, None, nil}, {IXor, newObject(inplaceType), None, None, nil}, @@ -320,23 +327,25 @@ func TestDelItem(t *testing.T) { } func TestFormatException(t *testing.T) { - f := NewRootFrame() - cases := []struct { - o *Object - want string - }{ - {mustNotRaise(ExceptionType.Call(f, nil, nil)), "Exception\n"}, - {mustNotRaise(AttributeErrorType.Call(f, 
wrapArgs(""), nil)), "AttributeError\n"}, - {mustNotRaise(TypeErrorType.Call(f, wrapArgs(123), nil)), "TypeError: 123\n"}, - {mustNotRaise(AttributeErrorType.Call(f, wrapArgs("hello", "there"), nil)), "AttributeError: ('hello', 'there')\n"}, + fun := wrapFuncForTest(func(f *Frame, t *Type, args ...*Object) (string, *BaseException) { + e, raised := t.Call(f, args, nil) + if raised != nil { + return "", raised + } + f.Raise(e, nil, nil) + s := FormatExc(f) + f.RestoreExc(nil, nil) + return s, nil + }) + cases := []invokeTestCase{ + {args: wrapArgs(ExceptionType), want: NewStr("Exception\n").ToObject()}, + {args: wrapArgs(AttributeErrorType, ""), want: NewStr("AttributeError\n").ToObject()}, + {args: wrapArgs(TypeErrorType, 123), want: NewStr("TypeError: 123\n").ToObject()}, + {args: wrapArgs(AttributeErrorType, "hello", "there"), want: NewStr("AttributeError: ('hello', 'there')\n").ToObject()}, } for _, cas := range cases { - if !cas.o.isInstance(BaseExceptionType) { - t.Errorf("expected FormatException() input to be BaseException, got %s", cas.o.typ.Name()) - } else if got, raised := FormatException(f, toBaseExceptionUnsafe(cas.o)); raised != nil { - t.Errorf("FormatException(%v) raised %v, want nil", cas.o, raised) - } else if got != cas.want { - t.Errorf("FormatException(%v) = %q, want %q", cas.o, got, cas.want) + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) } } } @@ -422,6 +431,32 @@ func TestHash(t *testing.T) { } } +func TestHex(t *testing.T) { + badHex := newTestClass("badHex", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__hex__": newBuiltinFunction("__hex__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + return NewInt(123).ToObject(), nil + }).ToObject(), + })) + goodHex := newTestClass("goodHex", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__hex__": newBuiltinFunction("__hex__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return NewStr("0x123").ToObject(), nil + }).ToObject(), + })) + cases := []invokeTestCase{ + {args: wrapArgs(-123), want: NewStr("-0x7b").ToObject()}, + {args: wrapArgs(123), want: NewStr("0x7b").ToObject()}, + {args: wrapArgs(newObject(goodHex)), want: NewStr("0x123").ToObject()}, + {args: wrapArgs(NewList()), wantExc: mustCreateException(TypeErrorType, "hex() argument can't be converted to hex")}, + {args: wrapArgs(NewDict()), wantExc: mustCreateException(TypeErrorType, "hex() argument can't be converted to hex")}, + {args: wrapArgs(newObject(badHex)), wantExc: mustCreateException(TypeErrorType, "__hex__ returned non-string (type int)")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(Hex), &cas); err != "" { + t.Error(err) + } + } +} + func TestIndex(t *testing.T) { goodType := newTestClass("GoodIndex", []*Type{ObjectType}, newStringDict(map[string]*Object{ "__index__": newBuiltinFunction("__index__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { @@ -736,14 +771,57 @@ func TestInvokeKeywordArgs(t *testing.T) { } } +func TestOct(t *testing.T) { + badOct := newTestClass("badOct", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__oct__": newBuiltinFunction("__oct__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + return NewInt(123).ToObject(), nil + }).ToObject(), + })) + goodOct := newTestClass("goodOct", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__oct__": newBuiltinFunction("__oct__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return 
NewStr("0123").ToObject(), nil + }).ToObject(), + })) + cases := []invokeTestCase{ + {args: wrapArgs(-123), want: NewStr("-0173").ToObject()}, + {args: wrapArgs(123), want: NewStr("0173").ToObject()}, + {args: wrapArgs(newObject(goodOct)), want: NewStr("0123").ToObject()}, + {args: wrapArgs(NewList()), wantExc: mustCreateException(TypeErrorType, "oct() argument can't be converted to oct")}, + {args: wrapArgs(NewDict()), wantExc: mustCreateException(TypeErrorType, "oct() argument can't be converted to oct")}, + {args: wrapArgs(newObject(badOct)), wantExc: mustCreateException(TypeErrorType, "__oct__ returned non-string (type int)")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(Oct), &cas); err != "" { + t.Error(err) + } + } +} + +func TestPos(t *testing.T) { + pos := newTestClass("pos", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__pos__": newBuiltinFunction("__pos__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return NewInt(-42).ToObject(), nil + }).ToObject(), + })) + cases := []invokeTestCase{ + {args: wrapArgs(42), want: NewInt(42).ToObject()}, + {args: wrapArgs(1.2), want: NewFloat(1.2).ToObject()}, + {args: wrapArgs(NewLong(big.NewInt(123))), want: NewLong(big.NewInt(123)).ToObject()}, + {args: wrapArgs(newObject(pos)), want: NewInt(-42).ToObject()}, + {args: wrapArgs("foo"), wantExc: mustCreateException(TypeErrorType, "bad operand type for unary +: 'str'")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(Pos), &cas); err != "" { + t.Error(err) + } + } +} + func TestPyPrint(t *testing.T) { fun := wrapFuncForTest(func(f *Frame, args *Tuple, sep, end string) (string, *BaseException) { - var buf bytes.Buffer - raised := pyPrint(NewRootFrame(), args.elems, sep, end, &buf) - if raised != nil { - return "", raised - } - return buf.String(), nil + return captureStdout(f, func() *BaseException { + return pyPrint(NewRootFrame(), args.elems, sep, end, Stdout) + }) }) cases := []invokeTestCase{ {args: wrapArgs(NewTuple(), "", "\n"), want: NewStr("\n").ToObject()}, @@ -758,7 +836,8 @@ func TestPyPrint(t *testing.T) { } } -func TestPrint(t *testing.T) { +// TODO(corona10): Re-enable once #282 is addressed. +/*func TestPrint(t *testing.T) { fun := wrapFuncForTest(func(f *Frame, args *Tuple, nl bool) (string, *BaseException) { return captureStdout(f, func() *BaseException { return Print(NewRootFrame(), args.elems, nl) @@ -775,7 +854,7 @@ func TestPrint(t *testing.T) { t.Error(err) } } -} +}*/ func TestReprRaise(t *testing.T) { testTypes := []*Type{ @@ -1087,6 +1166,20 @@ func TestToNative(t *testing.T) { } } +func BenchmarkGetAttr(b *testing.B) { + f := NewRootFrame() + attr := NewStr("bar") + fooType := newTestClass("Foo", []*Type{ObjectType}, NewDict()) + foo := newObject(fooType) + if raised := SetAttr(f, foo, attr, NewInt(123).ToObject()); raised != nil { + panic(raised) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + mustNotRaise(GetAttr(f, foo, attr, nil)) + } +} + // SetAttr is tested in TestObjectSetAttr. func exceptionsAreEquivalent(e1 *BaseException, e2 *BaseException) bool { diff --git a/runtime/descriptor.go b/runtime/descriptor.go index 32825ef4..ced46010 100644 --- a/runtime/descriptor.go +++ b/runtime/descriptor.go @@ -19,6 +19,13 @@ import ( "reflect" ) +type fieldDescriptorType int + +const ( + fieldDescriptorRO fieldDescriptorType = iota + fieldDescriptorRW +) + // Property represents Python 'property' objects. 
type Property struct { Object @@ -98,24 +105,57 @@ func propertySet(f *Frame, desc, inst, value *Object) *BaseException { // makeStructFieldDescriptor creates a descriptor with a getter that returns // the field given by fieldName from t's basis structure. -func makeStructFieldDescriptor(t *Type, fieldName, propertyName string) *Object { +func makeStructFieldDescriptor(t *Type, fieldName, propertyName string, fieldMode fieldDescriptorType) *Object { field, ok := t.basis.FieldByName(fieldName) if !ok { logFatal(fmt.Sprintf("no such field %q for basis %s", fieldName, nativeTypeName(t.basis))) } + getterFunc := func(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { - var ret *Object - var raised *BaseException - if raised = checkFunctionArgs(f, fieldName, args, ObjectType); raised == nil { - o := args[0] - if !o.isInstance(t) { + if raised := checkFunctionArgs(f, fieldName, args, ObjectType); raised != nil { + return nil, raised + } + + self := args[0] + if !self.isInstance(t) { + format := "descriptor '%s' for '%s' objects doesn't apply to '%s' objects" + return nil, f.RaiseType(TypeErrorType, fmt.Sprintf(format, propertyName, t.Name(), self.typ.Name())) + } + + return WrapNative(f, t.slots.Basis.Fn(self).FieldByIndex(field.Index)) + } + getter := newBuiltinFunction("_get"+fieldName, getterFunc).ToObject() + + setter := None + if fieldMode == fieldDescriptorRW { + if field.PkgPath != "" { + logFatal(fmt.Sprintf("field '%q' is not public on Golang code. Please fix it.", fieldName)) + } + + setterFunc := func(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + if raised := checkFunctionArgs(f, fieldName, args, ObjectType, ObjectType); raised != nil { + return nil, raised + } + + self := args[0] + newValue := args[1] + + if !self.isInstance(t) { format := "descriptor '%s' for '%s' objects doesn't apply to '%s' objects" - raised = f.RaiseType(TypeErrorType, fmt.Sprintf(format, propertyName, t.Name(), o.typ.Name())) - } else { - ret, raised = WrapNative(f, t.slots.Basis.Fn(o).FieldByIndex(field.Index)) + return nil, f.RaiseType(TypeErrorType, fmt.Sprintf(format, propertyName, t.Name(), self.typ.Name())) } + + val := t.slots.Basis.Fn(self).FieldByIndex(field.Index) + converted, raised := maybeConvertValue(f, newValue, field.Type) + if raised != nil { + return nil, raised + } + + val.Set(converted) + return None, nil } - return ret, raised + + setter = newBuiltinFunction("_set"+fieldName, setterFunc).ToObject() } - return newProperty(newBuiltinFunction("_get"+fieldName, getterFunc).ToObject(), None, None).ToObject() + return newProperty(getter, setter, None).ToObject() } diff --git a/runtime/descriptor_test.go b/runtime/descriptor_test.go index a9ae25d3..0d87f42c 100644 --- a/runtime/descriptor_test.go +++ b/runtime/descriptor_test.go @@ -90,7 +90,7 @@ func TestMakeStructFieldDescriptor(t *testing.T) { return nil, raised } t := toTypeUnsafe(args[0]) - desc := makeStructFieldDescriptor(t, toStrUnsafe(args[1]).Value(), toStrUnsafe(args[2]).Value()) + desc := makeStructFieldDescriptor(t, toStrUnsafe(args[1]).Value(), toStrUnsafe(args[2]).Value(), fieldDescriptorRO) get, raised := GetAttr(f, desc, NewStr("__get__"), nil) if raised != nil { return nil, raised @@ -110,3 +110,53 @@ func TestMakeStructFieldDescriptor(t *testing.T) { } } } + +func TestMakeStructFieldDescriptorRWGet(t *testing.T) { + fun := newBuiltinFunction("TestMakeStructFieldDescriptorRW_get", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, 
"TestMakeStructFieldDescriptorRW_get", args, TypeType, StrType, StrType, ObjectType); raised != nil { + return nil, raised + } + t := toTypeUnsafe(args[0]) + desc := makeStructFieldDescriptor(t, toStrUnsafe(args[1]).Value(), toStrUnsafe(args[2]).Value(), fieldDescriptorRW) + get, raised := GetAttr(f, desc, NewStr("__get__"), nil) + if raised != nil { + return nil, raised + } + return get.Call(f, wrapArgs(args[3], t), nil) + }).ToObject() + cases := []invokeTestCase{ + {args: wrapArgs(FileType, "Softspace", "softspace", newObject(FileType)), want: NewInt(0).ToObject()}, + {args: wrapArgs(FileType, "Softspace", "softspace", 42), wantExc: mustCreateException(TypeErrorType, "descriptor 'softspace' for 'file' objects doesn't apply to 'int' objects")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + +func TestMakeStructFieldDescriptorRWSet(t *testing.T) { + fun := newBuiltinFunction("TestMakeStructFieldDescriptorRW_set", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "TestMakeStructFieldDescriptorRW_set", args, TypeType, StrType, StrType, ObjectType, ObjectType); raised != nil { + return nil, raised + } + t := toTypeUnsafe(args[0]) + desc := makeStructFieldDescriptor(t, toStrUnsafe(args[1]).Value(), toStrUnsafe(args[2]).Value(), fieldDescriptorRW) + set, raised := GetAttr(f, desc, NewStr("__set__"), nil) + if raised != nil { + return nil, raised + } + return set.Call(f, wrapArgs(args[3], args[4]), nil) + }).ToObject() + cases := []invokeTestCase{ + {args: wrapArgs(FileType, "Softspace", "softspace", newObject(FileType), NewInt(0).ToObject()), want: None}, + {args: wrapArgs(FileType, "Softspace", "softspace", newObject(FileType), NewInt(0)), want: None}, + {args: wrapArgs(FileType, "Softspace", "softspace", newObject(FileType), "wrong"), wantExc: mustCreateException(TypeErrorType, "an int is required")}, + {args: wrapArgs(FileType, "Softspace", "softspace", 42, NewInt(0)), wantExc: mustCreateException(TypeErrorType, "descriptor 'softspace' for 'file' objects doesn't apply to 'int' objects")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} diff --git a/runtime/dict.go b/runtime/dict.go index c69b5f11..32eeca43 100644 --- a/runtime/dict.go +++ b/runtime/dict.go @@ -230,7 +230,7 @@ func (iter *dictEntryIterator) next() *dictEntry { // 64bit atomic ops need to be 8 byte aligned. This compile time check // verifies alignment by creating a negative constant for an unsigned type. // See sync/atomic docs for details. - const _ = -(unsafe.Offsetof(iter.index) % 8) + const blank = -(unsafe.Offsetof(iter.index) % 8) index := int(atomic.AddInt64(&iter.index, 1)) - 1 if index >= numEntries { break @@ -310,7 +310,7 @@ func (d *Dict) loadVersion() int64 { // 64bit atomic ops need to be 8 byte aligned. This compile time check // verifies alignment by creating a negative constant for an unsigned type. // See sync/atomic docs for details. - const _ = -(unsafe.Offsetof(d.version) % 8) + const blank = -(unsafe.Offsetof(d.version) % 8) return atomic.LoadInt64(&d.version) } @@ -319,32 +319,18 @@ func (d *Dict) incVersion() { // 64bit atomic ops need to be 8 byte aligned. This compile time check // verifies alignment by creating a negative constant for an unsigned type. // See sync/atomic docs for details. 
- const _ = -(unsafe.Offsetof(d.version) % 8) + const blank = -(unsafe.Offsetof(d.version) % 8) atomic.AddInt64(&d.version, 1) } // DelItem removes the entry associated with key from d. It returns true if an // item was removed, or false if it did not exist in d. func (d *Dict) DelItem(f *Frame, key *Object) (bool, *BaseException) { - hash, raised := Hash(f, key) + originValue, raised := d.putItem(f, key, nil, true) if raised != nil { return false, raised } - deleted := false - d.mutex.Lock(f) - v := d.version - if index, entry, raised := d.table.lookupEntry(f, hash.Value(), key); raised == nil { - if v != d.version { - raised = f.RaiseType(RuntimeErrorType, "dictionary changed during write") - } else if entry != nil && entry != deletedEntry { - d.table.storeEntry(index, deletedEntry) - d.table.incUsed(-1) - d.incVersion() - deleted = true - } - } - d.mutex.Unlock(f) - return deleted, raised + return originValue != nil, nil } // DelItemString removes the entry associated with key from d. It returns true @@ -376,6 +362,12 @@ func (d *Dict) GetItemString(f *Frame, key string) (*Object, *BaseException) { return d.GetItem(f, NewStr(key).ToObject()) } +// Pop looks up key in d, returning and removing the associated value if it exists, +// or nil if key is not present in d. +func (d *Dict) Pop(f *Frame, key *Object) (*Object, *BaseException) { + return d.putItem(f, key, nil, true) +} + // Keys returns a list containing all the keys in d. func (d *Dict) Keys(f *Frame) *List { d.mutex.Lock(f) @@ -396,43 +388,54 @@ func (d *Dict) Len() int { return d.loadTable().loadUsed() } -// putItem associates value with key in d, returning true if the key was added -// (i.e. it was not already present in d). -func (d *Dict) putItem(f *Frame, key, value *Object) (bool, *BaseException) { +// putItem associates value with key in d, returning the old associated value if +// the key was already present, or nil if it was not. +func (d *Dict) putItem(f *Frame, key, value *Object, overwrite bool) (*Object, *BaseException) { hash, raised := Hash(f, key) if raised != nil { - return false, raised + return nil, raised } d.mutex.Lock(f) t := d.table v := d.version index, entry, raised := t.lookupEntry(f, hash.Value(), key) - added := false + var originValue *Object if raised == nil { if v != d.version { // Dictionary was recursively modified. Blow up instead // of trying to recover. raised = f.RaiseType(RuntimeErrorType, "dictionary changed during write") } else { - if newTable, ok := t.writeEntry(f, index, &dictEntry{hash.Value(), key, value}); ok { - if newTable != nil { - d.storeTable(newTable) + if value == nil { + // Going to delete the entry. + if entry != nil && entry != deletedEntry { + d.table.storeEntry(index, deletedEntry) + d.table.incUsed(-1) + d.incVersion() } - d.incVersion() - // Key absent if entry == nil or deletedEntry. - added = entry == nil || entry == deletedEntry - } else { - raised = f.RaiseType(OverflowErrorType, errResultTooLarge) + } else if overwrite || entry == nil { + newEntry := &dictEntry{hash.Value(), key, value} + if newTable, ok := t.writeEntry(f, index, newEntry); ok { + if newTable != nil { + d.storeTable(newTable) + } + d.incVersion() + } else { + raised = f.RaiseType(OverflowErrorType, errResultTooLarge) + } + } + if entry != nil && entry != deletedEntry { + originValue = entry.value } } } d.mutex.Unlock(f) - return added, raised + return originValue, raised } // SetItem associates value with key in d.
func (d *Dict) SetItem(f *Frame, key, value *Object) *BaseException { - _, raised := d.putItem(f, key, value) + _, raised := d.putItem(f, key, value, true) return raised } @@ -536,6 +539,13 @@ func dictContains(f *Frame, seq, value *Object) (*Object, *BaseException) { return GetBool(item != nil).ToObject(), nil } +func dictCopy(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "copy", args, DictType); raised != nil { + return nil, raised + } + return DictType.Call(f, args, nil) +} + func dictDelItem(f *Frame, o, key *Object) *BaseException { deleted, raised := toDictUnsafe(o).DelItem(f, key) if raised != nil { @@ -577,6 +587,13 @@ func dictGet(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return item, raised } +func dictHasKey(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "has_key", args, DictType, ObjectType); raised != nil { + return nil, raised + } + return dictContains(f, args[0], args[1]) +} + func dictItems(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { if raised := checkMethodArgs(f, "items", args, DictType); raised != nil { return nil, raised @@ -673,7 +690,7 @@ func dictLen(f *Frame, o *Object) (*Object, *BaseException) { } func dictNE(f *Frame, v, w *Object) (*Object, *BaseException) { - if !v.isInstance(DictType) { + if !w.isInstance(DictType) { return NotImplemented, nil } eq, raised := dictsAreEqual(f, toDictUnsafe(v), toDictUnsafe(w)) @@ -689,6 +706,48 @@ func dictNew(f *Frame, t *Type, _ Args, _ KWArgs) (*Object, *BaseException) { return d.ToObject(), nil } +func dictPop(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + expectedTypes := []*Type{DictType, ObjectType, ObjectType} + argc := len(args) + if argc == 2 { + expectedTypes = expectedTypes[:2] + } + if raised := checkMethodArgs(f, "pop", args, expectedTypes...); raised != nil { + return nil, raised + } + key := args[1] + d := toDictUnsafe(args[0]) + item, raised := d.Pop(f, key) + if raised == nil && item == nil { + if argc > 2 { + item = args[2] + } else { + raised = raiseKeyError(f, key) + } + } + return item, raised +} + +func dictPopItem(f *Frame, args Args, _ KWArgs) (item *Object, raised *BaseException) { + if raised := checkMethodArgs(f, "popitem", args, DictType); raised != nil { + return nil, raised + } + d := toDictUnsafe(args[0]) + d.mutex.Lock(f) + iter := newDictEntryIterator(d) + entry := iter.next() + if entry == nil { + raised = f.RaiseType(KeyErrorType, "popitem(): dictionary is empty") + } else { + item = NewTuple(entry.key, entry.value).ToObject() + d.table.storeEntry(int(iter.index-1), deletedEntry) + d.table.incUsed(-1) + d.incVersion() + } + d.mutex.Unlock(f) + return item, raised +} + func dictRepr(f *Frame, o *Object) (*Object, *BaseException) { d := toDictUnsafe(o) if f.reprEnter(d.ToObject()) { @@ -723,6 +782,36 @@ func dictRepr(f *Frame, o *Object) (*Object, *BaseException) { return NewStr(buf.String()).ToObject(), nil } +func dictSetDefault(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + argc := len(args) + if argc == 1 { + return nil, f.RaiseType(TypeErrorType, "setdefault expected at least 1 arguments, got 0") + } + if argc > 3 { + return nil, f.RaiseType(TypeErrorType, fmt.Sprintf("setdefault expected at most 2 arguments, got %v", argc-1)) + } + expectedTypes := []*Type{DictType, ObjectType, ObjectType} + if argc == 2 { + expectedTypes = expectedTypes[:2] + } + if raised := checkMethodArgs(f, "setdefault", args, expectedTypes...); raised != 
nil { + return nil, raised + } + d := toDictUnsafe(args[0]) + key := args[1] + var value *Object + if argc > 2 { + value = args[2] + } else { + value = None + } + originValue, raised := d.putItem(f, key, value, false) + if originValue != nil { + return originValue, raised + } + return value, raised +} + func dictSetItem(f *Frame, o, key, value *Object) *BaseException { return toDictUnsafe(o).SetItem(f, key, value) } @@ -763,12 +852,17 @@ func dictValues(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { func initDictType(dict map[string]*Object) { dict["clear"] = newBuiltinFunction("clear", dictClear).ToObject() + dict["copy"] = newBuiltinFunction("copy", dictCopy).ToObject() dict["get"] = newBuiltinFunction("get", dictGet).ToObject() + dict["has_key"] = newBuiltinFunction("has_key", dictHasKey).ToObject() dict["items"] = newBuiltinFunction("items", dictItems).ToObject() dict["iteritems"] = newBuiltinFunction("iteritems", dictIterItems).ToObject() dict["iterkeys"] = newBuiltinFunction("iterkeys", dictIterKeys).ToObject() dict["itervalues"] = newBuiltinFunction("itervalues", dictIterValues).ToObject() dict["keys"] = newBuiltinFunction("keys", dictKeys).ToObject() + dict["pop"] = newBuiltinFunction("pop", dictPop).ToObject() + dict["popitem"] = newBuiltinFunction("popitem", dictPopItem).ToObject() + dict["setdefault"] = newBuiltinFunction("setdefault", dictSetDefault).ToObject() dict["update"] = newBuiltinFunction("update", dictUpdate).ToObject() dict["values"] = newBuiltinFunction("values", dictValues).ToObject() DictType.slots.Contains = &binaryOpSlot{dictContains} @@ -819,7 +913,7 @@ func dictItemIteratorNext(f *Frame, o *Object) (ret *Object, raised *BaseExcepti if raised != nil { return nil, raised } - return NewTuple(entry.key, entry.value).ToObject(), nil + return NewTuple2(entry.key, entry.value).ToObject(), nil } func initDictItemIteratorType(map[string]*Object) { diff --git a/runtime/dict_test.go b/runtime/dict_test.go index 84ac2b62..a6545c4d 100644 --- a/runtime/dict_test.go +++ b/runtime/dict_test.go @@ -132,34 +132,21 @@ func TestDictDelItemString(t *testing.T) { } func TestDictEqNE(t *testing.T) { - fun := newBuiltinFunction("TestDictEqNE", func(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { - if raised := checkMethodArgs(f, "TestDictEqNE", args, DictType, DictType, BoolType); raised != nil { - return nil, raised - } - d1, d2 := args[0], args[1] - wantEq := toIntUnsafe(args[2]).IsTrue() - if eq, raised := Eq(f, d1, d2); raised != nil { - return nil, raised - } else if !eq.isInstance(BoolType) || toIntUnsafe(eq).IsTrue() != wantEq { - t.Errorf("Eq(%v, %v) = %v, want %v", d1, d2, eq, GetBool(wantEq)) - } - if eq, raised := Eq(f, d2, d1); raised != nil { + fun := wrapFuncForTest(func(f *Frame, v, w *Object) (*Object, *BaseException) { + eq, raised := Eq(f, v, w) + if raised != nil { return nil, raised - } else if !eq.isInstance(BoolType) || toIntUnsafe(eq).IsTrue() != wantEq { - t.Errorf("Eq(%v, %v) = %v, want %v", d2, d1, eq, GetBool(wantEq)) } - if ne, raised := NE(f, d1, d2); raised != nil { + ne, raised := NE(f, v, w) + if raised != nil { return nil, raised - } else if !ne.isInstance(BoolType) || toIntUnsafe(ne).IsTrue() == wantEq { - t.Errorf("NE(%v, %v) = %v, want %v", d1, d2, ne, GetBool(!wantEq)) } - if ne, raised := NE(f, d2, d1); raised != nil { + valid := GetBool(eq == True.ToObject() && ne == False.ToObject() || eq == False.ToObject() && ne == True.ToObject()).ToObject() + if raised := Assert(f, valid, NewStr("invalid values for __eq__ or 
__ne__").ToObject()); raised != nil { return nil, raised - } else if !ne.isInstance(BoolType) || toIntUnsafe(ne).IsTrue() == wantEq { - t.Errorf("NE(%v, %v) = %v, want %v", d2, d1, ne, GetBool(!wantEq)) } - return None, nil - }).ToObject() + return eq, nil + }) f := NewRootFrame() large1, large2 := NewDict(), NewDict() largeSize := 100 @@ -177,15 +164,16 @@ func TestDictEqNE(t *testing.T) { } o := newObject(ObjectType) cases := []invokeTestCase{ - {args: wrapArgs(NewDict(), NewDict(), true), want: None}, - {args: wrapArgs(NewDict(), newTestDict("foo", true), false), want: None}, - {args: wrapArgs(newTestDict("foo", "foo"), newTestDict("foo", "foo"), true), want: None}, - {args: wrapArgs(newTestDict("foo", true), newTestDict("bar", true), false), want: None}, - {args: wrapArgs(newTestDict("foo", true), newTestDict("foo", newObject(ObjectType)), false), want: None}, - {args: wrapArgs(newTestDict("foo", true, "bar", false), newTestDict("bar", true), false), want: None}, - {args: wrapArgs(newTestDict("foo", o, "bar", o), newTestDict("foo", o, "bar", o), true), want: None}, - {args: wrapArgs(newTestDict(2, None, "foo", o), newTestDict("foo", o, 2, None), true), want: None}, - {args: wrapArgs(large1, large2, true), want: None}, + {args: wrapArgs(NewDict(), NewDict()), want: True.ToObject()}, + {args: wrapArgs(NewDict(), newTestDict("foo", true)), want: False.ToObject()}, + {args: wrapArgs(newTestDict("foo", "foo"), newTestDict("foo", "foo")), want: True.ToObject()}, + {args: wrapArgs(newTestDict("foo", true), newTestDict("bar", true)), want: False.ToObject()}, + {args: wrapArgs(newTestDict("foo", true), newTestDict("foo", newObject(ObjectType))), want: False.ToObject()}, + {args: wrapArgs(newTestDict("foo", true, "bar", false), newTestDict("bar", true)), want: False.ToObject()}, + {args: wrapArgs(newTestDict("foo", o, "bar", o), newTestDict("foo", o, "bar", o)), want: True.ToObject()}, + {args: wrapArgs(newTestDict(2, None, "foo", o), newTestDict("foo", o, 2, None)), want: True.ToObject()}, + {args: wrapArgs(large1, large2), want: True.ToObject()}, + {args: wrapArgs(NewDict(), 123), want: False.ToObject()}, } for _, cas := range cases { if err := runInvokeTestCase(fun, &cas); err != "" { @@ -409,6 +397,19 @@ func TestDictGetItemString(t *testing.T) { } } +func TestDictHasKey(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(NewDict(), "foo"), want: False.ToObject()}, + {args: wrapArgs(newTestDict("foo", 1, "bar", 2), "foo"), want: True.ToObject()}, + {args: wrapArgs(newTestDict(3, "foo", "bar", 42), 42), want: False.ToObject()}, + } + for _, cas := range cases { + if err := runInvokeMethodTestCase(DictType, "has_key", &cas); err != "" { + t.Error(err) + } + } +} + func TestDictItemIteratorIter(t *testing.T) { iter := &newDictItemIterator(NewDict()).Object cas := &invokeTestCase{args: wrapArgs(iter), want: iter} @@ -571,6 +572,50 @@ func TestDictKeys(t *testing.T) { } } +func TestDictPop(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(newTestDict("foo", 42), "foo"), want: NewInt(42).ToObject()}, + {args: wrapArgs(NewDict(), "foo", 42), want: NewInt(42).ToObject()}, + {args: wrapArgs(NewDict(), "foo"), wantExc: mustCreateException(KeyErrorType, "foo")}, + } + for _, cas := range cases { + if err := runInvokeMethodTestCase(DictType, "pop", &cas); err != "" { + t.Error(err) + } + } +} + +func TestDictPopItem(t *testing.T) { + popItem := mustNotRaise(GetAttr(NewRootFrame(), DictType.ToObject(), NewStr("popitem"), nil)) + fun := wrapFuncForTest(func(f *Frame, d *Dict) 
(*Object, *BaseException) { + result := NewDict() + item, raised := popItem.Call(f, wrapArgs(d), nil) + for ; raised == nil; item, raised = popItem.Call(f, wrapArgs(d), nil) { + t := toTupleUnsafe(item) + result.SetItem(f, t.GetItem(0), t.GetItem(1)) + } + if raised != nil { + if !raised.isInstance(KeyErrorType) { + return nil, raised + } + f.RestoreExc(nil, nil) + } + if raised = Assert(f, GetBool(d.Len() == 0).ToObject(), nil); raised != nil { + return nil, raised + } + return result.ToObject(), nil + }) + cases := []invokeTestCase{ + {args: wrapArgs(newTestDict("foo", 42)), want: newTestDict("foo", 42).ToObject()}, + {args: wrapArgs(newTestDict("foo", 42, 123, "bar")), want: newTestDict("foo", 42, 123, "bar").ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + func TestDictNewInit(t *testing.T) { cases := []invokeTestCase{ {args: wrapArgs(), want: NewDict().ToObject()}, @@ -601,6 +646,30 @@ func TestDictNewRaises(t *testing.T) { } } +func TestDictSetDefault(t *testing.T) { + setDefaultMethod := mustNotRaise(GetAttr(NewRootFrame(), DictType.ToObject(), NewStr("setdefault"), nil)) + setDefault := newBuiltinFunction("TestDictSetDefault", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + i, raised := setDefaultMethod.Call(f, args, kwargs) + if raised != nil { + return nil, raised + } + return NewTuple(i, args[0]).ToObject(), nil + }).ToObject() + cases := []invokeTestCase{ + {args: wrapArgs(NewDict(), "foo"), want: newTestTuple(None, newTestDict("foo", None)).ToObject()}, + {args: wrapArgs(NewDict(), "foo", 42), want: newTestTuple(42, newTestDict("foo", 42)).ToObject()}, + {args: wrapArgs(newTestDict("foo", 42), "foo"), want: newTestTuple(42, newTestDict("foo", 42)).ToObject()}, + {args: wrapArgs(newTestDict("foo", 42), "foo", 43), want: newTestTuple(42, newTestDict("foo", 42)).ToObject()}, + {args: wrapArgs(NewDict()), wantExc: mustCreateException(TypeErrorType, "setdefault expected at least 1 arguments, got 0")}, + {args: wrapArgs(NewDict(), "foo", "bar", "baz"), wantExc: mustCreateException(TypeErrorType, "setdefault expected at most 2 arguments, got 3")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(setDefault, &cas); err != "" { + t.Error(err) + } + } +} + func TestDictSetItem(t *testing.T) { setItem := newBuiltinFunction("TestDictSetItem", func(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkFunctionArgs(f, "TestDictSetItem", args, DictType, ObjectType, ObjectType); raised != nil { diff --git a/runtime/exceptions.go b/runtime/exceptions.go index 0db7ec89..ee0481a6 100644 --- a/runtime/exceptions.go +++ b/runtime/exceptions.go @@ -28,6 +28,8 @@ var ( // EnvironmentErrorType corresponds to the Python type // 'EnvironmentError'. EnvironmentErrorType = newSimpleType("EnvironmentError", StandardErrorType) + // EOFErrorType corresponds to the Python type 'EOFError'. + EOFErrorType = newSimpleType("EOFError", StandardErrorType) // ExceptionType corresponds to the Python type 'Exception'. ExceptionType = newSimpleType("Exception", BaseExceptionType) // FutureWarningType corresponds to the Python type 'FutureWarning'. @@ -40,6 +42,8 @@ var ( IndexErrorType = newSimpleType("IndexError", LookupErrorType) // IOErrorType corresponds to the Python type 'IOError'. IOErrorType = newSimpleType("IOError", EnvironmentErrorType) + // KeyboardInterruptType corresponds to the Python type 'KeyboardInterrupt'. 
+ KeyboardInterruptType = newSimpleType("KeyboardInterrupt", BaseExceptionType) // KeyErrorType corresponds to the Python type 'KeyError'. KeyErrorType = newSimpleType("KeyError", LookupErrorType) // LookupErrorType corresponds to the Python type 'LookupError'. diff --git a/runtime/file.go b/runtime/file.go index a5b1c46c..8fac4496 100644 --- a/runtime/file.go +++ b/runtime/file.go @@ -35,16 +35,26 @@ type File struct { mutex sync.Mutex mode string open bool + Softspace int `attr:"softspace" attr_mode:"rw"` reader *bufio.Reader file *os.File skipNextLF bool univNewLine bool + close *Object } // NewFileFromFD creates a file object from the given file descriptor fd. -func NewFileFromFD(fd uintptr) *File { +func NewFileFromFD(fd uintptr, close *Object) *File { // TODO: Use fcntl or something to get the mode of the descriptor. - file := &File{Object: Object{typ: FileType}, mode: "?", open: true, file: os.NewFile(fd, "")} + file := &File{ + Object: Object{typ: FileType}, + mode: "?", + open: true, + file: os.NewFile(fd, ""), + } + if close != None { + file.close = close + } file.reader = bufio.NewReader(file.file) return file } @@ -53,6 +63,14 @@ func toFileUnsafe(o *Object) *File { return (*File)(o.toPointer()) } +func (f *File) name() string { + name := "" + if f.file != nil { + name = f.file.Name() + } + return name +} + // ToObject upcasts f to an Object. func (f *File) ToObject() *Object { return &f.Object @@ -89,6 +107,19 @@ func (f *File) readLine(maxBytes int) (string, error) { return buf.String(), nil } +func (f *File) writeString(s string) error { + f.mutex.Lock() + defer f.mutex.Unlock() + if !f.open { + return io.ErrClosedPipe + } + if _, err := f.file.Write([]byte(s)); err != nil { + return err + } + + return nil +} + // FileType is the object representing the Python 'file' type. var FileType = newBasisType("file", reflect.TypeOf(File{}), toFileUnsafe, ObjectType) @@ -114,6 +145,11 @@ func fileInit(f *Frame, o *Object, args Args, _ KWArgs) (*Object, *BaseException flag = os.O_RDONLY case "r+", "r+b": flag = os.O_RDWR + // Difference between r+ and a+ is that a+ automatically creates file. 
+ case "a+": + flag = os.O_RDWR | os.O_CREATE | os.O_APPEND + case "w+": + flag = os.O_RDWR | os.O_CREATE case "w", "wb": flag = os.O_WRONLY | os.O_CREATE | os.O_TRUNC default: @@ -163,13 +199,59 @@ func fileClose(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { file := toFileUnsafe(args[0]) file.mutex.Lock() defer file.mutex.Unlock() - if file.open && file.file != nil { - if err := file.file.Close(); err != nil { - return nil, f.RaiseType(IOErrorType, err.Error()) + ret := None + if file.open { + var raised *BaseException + if file.close != nil { + ret, raised = file.close.Call(f, args, nil) + } else if file.file != nil { + if err := file.file.Close(); err != nil { + raised = f.RaiseType(IOErrorType, err.Error()) + } + } + if raised != nil { + return nil, raised } } file.open = false - return None, nil + return ret, nil +} + +func fileClosed(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "closed", args, FileType); raised != nil { + return nil, raised + } + file := toFileUnsafe(args[0]) + file.mutex.Lock() + c := !file.open + file.mutex.Unlock() + return GetBool(c).ToObject(), nil +} + +func fileFileno(f *Frame, args Args, _ KWArgs) (ret *Object, raised *BaseException) { + if raised := checkMethodArgs(f, "fileno", args, FileType); raised != nil { + return nil, raised + } + file := toFileUnsafe(args[0]) + file.mutex.Lock() + if file.open { + ret = NewInt(int(file.file.Fd())).ToObject() + } else { + raised = f.RaiseType(ValueErrorType, "I/O operation on closed file") + } + file.mutex.Unlock() + return ret, raised +} + +func fileGetName(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "_get_name", args, FileType); raised != nil { + return nil, raised + } + file := toFileUnsafe(args[0]) + file.mutex.Lock() + name := file.name() + file.mutex.Unlock() + return NewStr(name).ToObject(), nil } func fileIter(f *Frame, o *Object) (*Object, *BaseException) { @@ -213,7 +295,7 @@ func fileRead(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { n, err = file.reader.Read(data) data = data[:n] } - if err != nil { + if err != nil && err != io.EOF { return nil, f.RaiseType(IOErrorType, err.Error()) } return NewStr(string(data)).ToObject(), nil @@ -277,19 +359,13 @@ func fileRepr(f *Frame, o *Object) (*Object, *BaseException) { } else { openState = "closed" } - var name string - if file.file != nil { - name = file.file.Name() - } else { - name = "" - } var mode string if file.mode != "" { mode = file.mode } else { mode = "" } - return NewStr(fmt.Sprintf("<%s file %q, mode %q at %p>", openState, name, mode, file)).ToObject(), nil + return NewStr(fmt.Sprintf("<%s file %q, mode %q at %p>", openState, file.name(), mode, file)).ToObject(), nil } func fileWrite(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { @@ -313,6 +389,9 @@ func initFileType(dict map[string]*Object) { dict["__enter__"] = newBuiltinFunction("__enter__", fileEnter).ToObject() dict["__exit__"] = newBuiltinFunction("__exit__", fileExit).ToObject() dict["close"] = newBuiltinFunction("close", fileClose).ToObject() + dict["closed"] = newBuiltinFunction("closed", fileClosed).ToObject() + dict["fileno"] = newBuiltinFunction("fileno", fileFileno).ToObject() + dict["name"] = newProperty(newBuiltinFunction("_get_name", fileGetName).ToObject(), nil, nil).ToObject() dict["read"] = newBuiltinFunction("read", fileRead).ToObject() dict["readline"] = newBuiltinFunction("readline", fileReadLine).ToObject() dict["readlines"] = 
newBuiltinFunction("readlines", fileReadLines).ToObject() @@ -342,3 +421,12 @@ func fileParseReadArgs(f *Frame, method string, args Args) (*File, int, *BaseExc } return toFileUnsafe(args[0]), size, nil } + +var ( + // Stdin is an alias for sys.stdin. + Stdin = NewFileFromFD(os.Stdin.Fd(), nil) + // Stdout is an alias for sys.stdout. + Stdout = NewFileFromFD(os.Stdout.Fd(), nil) + // Stderr is an alias for sys.stderr. + Stderr = NewFileFromFD(os.Stderr.Fd(), nil) +) diff --git a/runtime/file_test.go b/runtime/file_test.go index 74227ac6..d292408d 100644 --- a/runtime/file_test.go +++ b/runtime/file_test.go @@ -34,7 +34,6 @@ func TestFileInit(t *testing.T) { {args: wrapArgs(newObject(FileType), f.path), want: None}, {args: wrapArgs(newObject(FileType)), wantExc: mustCreateException(TypeErrorType, "'__init__' requires 2 arguments")}, {args: wrapArgs(newObject(FileType), f.path, "abc"), wantExc: mustCreateException(ValueErrorType, `invalid mode string: "abc"`)}, - {args: wrapArgs(newObject(FileType), f.path, "w+"), wantExc: mustCreateException(ValueErrorType, `invalid mode string: "w+"`)}, {args: wrapArgs(newObject(FileType), "nonexistent-file"), wantExc: mustCreateException(IOErrorType, "open nonexistent-file: no such file or directory")}, } for _, cas := range cases { @@ -44,6 +43,25 @@ func TestFileInit(t *testing.T) { } } +func TestFileClosed(t *testing.T) { + f := newTestFile("foo\nbar") + defer f.cleanup() + closedFile := f.open("r") + // This puts the file into an invalid state since Grumpy thinks + // it's open even though the underlying file was closed. + closedFile.file.Close() + cases := []invokeTestCase{ + {args: wrapArgs(newObject(FileType)), want: True.ToObject()}, + {args: wrapArgs(f.open("r")), want: False.ToObject()}, + {args: wrapArgs(closedFile), want: False.ToObject()}, + } + for _, cas := range cases { + if err := runInvokeMethodTestCase(FileType, "closed", &cas); err != "" { + t.Error(err) + } + } +} + func TestFileCloseExit(t *testing.T) { f := newTestFile("foo\nbar") defer f.cleanup() @@ -55,7 +73,7 @@ func TestFileCloseExit(t *testing.T) { cases := []invokeTestCase{ {args: wrapArgs(newObject(FileType)), want: None}, {args: wrapArgs(f.open("r")), want: None}, - {args: wrapArgs(closedFile), wantExc: mustCreateException(IOErrorType, "invalid argument")}, + {args: wrapArgs(closedFile), wantExc: mustCreateException(IOErrorType, closedFile.file.Close().Error())}, } for _, cas := range cases { if err := runInvokeMethodTestCase(FileType, method, &cas); err != "" { @@ -65,6 +83,23 @@ func TestFileCloseExit(t *testing.T) { } } +func TestFileGetName(t *testing.T) { + fun := wrapFuncForTest(func(f *Frame, file *File) (*Object, *BaseException) { + return GetAttr(f, file.ToObject(), NewStr("name"), nil) + }) + foo := newTestFile("foo") + defer foo.cleanup() + cases := []invokeTestCase{ + {args: wrapArgs(foo.open("r")), want: NewStr(foo.path).ToObject()}, + {args: wrapArgs(newObject(FileType)), want: NewStr("").ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + func TestFileIter(t *testing.T) { files := makeTestFiles() defer files.cleanup() @@ -303,7 +338,7 @@ func TestFileWrite(t *testing.T) { t.Fatalf("Chdir(%q) failed: %s", dir, err) } defer os.Chdir(oldWd) - for _, filename := range []string{"truncate.txt", "readonly.txt", "append.txt", "rplus.txt"} { + for _, filename := range []string{"truncate.txt", "readonly.txt", "append.txt", "rplus.txt", "aplus.txt", "wplus.txt"} { if err := 
ioutil.WriteFile(filename, []byte(filename), 0644); err != nil { t.Fatalf("ioutil.WriteFile(%q) failed: %s", filename, err) } @@ -312,7 +347,16 @@ func TestFileWrite(t *testing.T) { {args: wrapArgs("noexist.txt", "w", "foo\nbar"), want: NewStr("foo\nbar").ToObject()}, {args: wrapArgs("truncate.txt", "w", "new contents"), want: NewStr("new contents").ToObject()}, {args: wrapArgs("append.txt", "a", "\nbar"), want: NewStr("append.txt\nbar").ToObject()}, + {args: wrapArgs("rplus.txt", "r+", "fooey"), want: NewStr("fooey.txt").ToObject()}, + {args: wrapArgs("noexistplus1.txt", "r+", "pooey"), wantExc: mustCreateException(IOErrorType, "open noexistplus1.txt: no such file or directory")}, + + {args: wrapArgs("aplus.txt", "a+", "\napper"), want: NewStr("aplus.txt\napper").ToObject()}, + {args: wrapArgs("noexistplus3.txt", "a+", "snappbacktoreality"), want: NewStr("snappbacktoreality").ToObject()}, + + {args: wrapArgs("wplus.txt", "w+", "destructo"), want: NewStr("destructo").ToObject()}, + {args: wrapArgs("noexistplus2.txt", "w+", "wapper"), want: NewStr("wapper").ToObject()}, + {args: wrapArgs("readonly.txt", "r", "foo"), wantExc: mustCreateException(IOErrorType, "write readonly.txt: bad file descriptor")}, } for _, cas := range cases { diff --git a/runtime/float.go b/runtime/float.go index 8d89c232..4ee3ae73 100644 --- a/runtime/float.go +++ b/runtime/float.go @@ -20,6 +20,10 @@ import ( "math/big" "reflect" "strconv" + "strings" + "sync/atomic" + "unicode" + "unsafe" ) // FloatType is the object representing the Python 'float' type. @@ -29,11 +33,12 @@ var FloatType = newBasisType("float", reflect.TypeOf(Float{}), toFloatUnsafe, Ob type Float struct { Object value float64 + hash int } // NewFloat returns a new Float holding the given floating point value. 
func NewFloat(value float64) *Float { - return &Float{Object{typ: FloatType}, value} + return &Float{Object: Object{typ: FloatType}, value: value} } func toFloatUnsafe(o *Object) *Float { @@ -68,6 +73,16 @@ func floatDiv(f *Frame, v, w *Object) (*Object, *BaseException) { }) } +func floatDivMod(f *Frame, v, w *Object) (*Object, *BaseException) { + return floatDivAndModOp(f, "__divmod__", v, w, func(v, w float64) (float64, float64, bool) { + m, r := floatModFunc(v, w) + if !r { + return 0, 0, false + } + return math.Floor(v / w), m, true + }) +} + func floatEq(f *Frame, v, w *Object) (*Object, *BaseException) { return floatCompare(toFloatUnsafe(v), w, False, True, False), nil } @@ -76,6 +91,15 @@ func floatFloat(f *Frame, o *Object) (*Object, *BaseException) { return o, nil } +func floatFloorDiv(f *Frame, v, w *Object) (*Object, *BaseException) { + return floatDivModOp(f, "__floordiv__", v, w, func(v, w float64) (float64, bool) { + if w == 0.0 { + return 0, false + } + return math.Floor(v / w), true + }) +} + func floatGE(f *Frame, v, w *Object) (*Object, *BaseException) { return floatCompare(toFloatUnsafe(v), w, False, True, True), nil } @@ -84,13 +108,29 @@ func floatGetNewArgs(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkMethodArgs(f, "__getnewargs__", args, FloatType); raised != nil { return nil, raised } - return NewTuple(args[0]).ToObject(), nil + return NewTuple1(args[0]).ToObject(), nil } func floatGT(f *Frame, v, w *Object) (*Object, *BaseException) { return floatCompare(toFloatUnsafe(v), w, False, False, True), nil } +func floatHash(f *Frame, o *Object) (*Object, *BaseException) { + v := toFloatUnsafe(o) + p := (*unsafe.Pointer)(unsafe.Pointer(&v.hash)) + if lp := atomic.LoadPointer(p); lp != unsafe.Pointer(nil) { + return (*Int)(lp).ToObject(), nil + } + hash := hashFloat(v.Value()) + if hash == -1 { + hash-- + } + h := NewInt(hash) + atomic.StorePointer(p, unsafe.Pointer(h)) + + return h.ToObject(), nil +} + func floatInt(f *Frame, o *Object) (*Object, *BaseException) { val := toFloatUnsafe(o).Value() if math.IsInf(val, 0) { @@ -169,15 +209,11 @@ func floatNew(f *Frame, t *Type, args Args, _ KWArgs) (*Object, *BaseException) } o := args[0] if floatSlot := o.typ.slots.Float; floatSlot != nil { - result, raised := floatSlot.Fn(f, o) + fl, raised := floatConvert(floatSlot, f, o) if raised != nil { return nil, raised } - if raised == nil && !result.isInstance(FloatType) { - exc := fmt.Sprintf("__float__ returned non-float (type %s)", result.typ.Name()) - return nil, f.RaiseType(TypeErrorType, exc) - } - return result, nil + return fl.ToObject(), nil } if !o.isInstance(StrType) { return nil, f.RaiseType(TypeErrorType, "float() argument must be a string or a number") @@ -194,6 +230,10 @@ func floatNonZero(f *Frame, o *Object) (*Object, *BaseException) { return GetBool(toFloatUnsafe(o).Value() != 0).ToObject(), nil } +func floatPos(f *Frame, o *Object) (*Object, *BaseException) { + return o, nil +} + func floatPow(f *Frame, v, w *Object) (*Object, *BaseException) { return floatArithmeticOp(f, "__pow__", v, w, func(v, w float64) float64 { return math.Pow(v, w) }) } @@ -211,8 +251,32 @@ func floatRDiv(f *Frame, v, w *Object) (*Object, *BaseException) { }) } +func floatRDivMod(f *Frame, v, w *Object) (*Object, *BaseException) { + return floatDivAndModOp(f, "__rdivmod__", v, w, func(v, w float64) (float64, float64, bool) { + m, r := floatModFunc(w, v) + if !r { + return 0, 0, false + } + return w / v, m, true + }) +} + +const ( + floatReprPrecision = 16 + 
floatStrPrecision = 12 +) + func floatRepr(f *Frame, o *Object) (*Object, *BaseException) { - return NewStr(strconv.FormatFloat(toFloatUnsafe(o).Value(), 'g', -1, 64)).ToObject(), nil + return NewStr(floatToString(toFloatUnsafe(o).Value(), floatReprPrecision)).ToObject(), nil +} + +func floatRFloorDiv(f *Frame, v, w *Object) (*Object, *BaseException) { + return floatDivModOp(f, "__rfloordiv__", v, w, func(v, w float64) (float64, bool) { + if v == 0.0 { + return 0, false + } + return math.Floor(w / v), true + }) } func floatRMod(f *Frame, v, w *Object) (*Object, *BaseException) { @@ -233,6 +297,10 @@ func floatRSub(f *Frame, v, w *Object) (*Object, *BaseException) { return floatArithmeticOp(f, "__rsub__", v, w, func(v, w float64) float64 { return w - v }) } +func floatStr(f *Frame, o *Object) (*Object, *BaseException) { + return NewStr(floatToString(toFloatUnsafe(o).Value(), floatStrPrecision)).ToObject(), nil +} + func floatSub(f *Frame, v, w *Object) (*Object, *BaseException) { return floatArithmeticOp(f, "__sub__", v, w, func(v, w float64) float64 { return v - w }) } @@ -242,10 +310,13 @@ func initFloatType(dict map[string]*Object) { FloatType.slots.Abs = &unaryOpSlot{floatAbs} FloatType.slots.Add = &binaryOpSlot{floatAdd} FloatType.slots.Div = &binaryOpSlot{floatDiv} + FloatType.slots.DivMod = &binaryOpSlot{floatDivMod} FloatType.slots.Eq = &binaryOpSlot{floatEq} FloatType.slots.Float = &unaryOpSlot{floatFloat} + FloatType.slots.FloorDiv = &binaryOpSlot{floatFloorDiv} FloatType.slots.GE = &binaryOpSlot{floatGE} FloatType.slots.GT = &binaryOpSlot{floatGT} + FloatType.slots.Hash = &unaryOpSlot{floatHash} FloatType.slots.Int = &unaryOpSlot{floatInt} FloatType.slots.Long = &unaryOpSlot{floatLong} FloatType.slots.LE = &binaryOpSlot{floatLE} @@ -257,14 +328,18 @@ func initFloatType(dict map[string]*Object) { FloatType.slots.Neg = &unaryOpSlot{floatNeg} FloatType.slots.New = &newSlot{floatNew} FloatType.slots.NonZero = &unaryOpSlot{floatNonZero} + FloatType.slots.Pos = &unaryOpSlot{floatPos} FloatType.slots.Pow = &binaryOpSlot{floatPow} FloatType.slots.RAdd = &binaryOpSlot{floatRAdd} FloatType.slots.RDiv = &binaryOpSlot{floatRDiv} + FloatType.slots.RDivMod = &binaryOpSlot{floatRDivMod} FloatType.slots.Repr = &unaryOpSlot{floatRepr} + FloatType.slots.RFloorDiv = &binaryOpSlot{floatRFloorDiv} FloatType.slots.RMod = &binaryOpSlot{floatRMod} FloatType.slots.RMul = &binaryOpSlot{floatRMul} FloatType.slots.RPow = &binaryOpSlot{floatRPow} FloatType.slots.RSub = &binaryOpSlot{floatRSub} + FloatType.slots.Str = &unaryOpSlot{floatStr} FloatType.slots.Sub = &binaryOpSlot{floatSub} } @@ -335,6 +410,18 @@ func floatCoerce(o *Object) (float64, bool) { } } +func floatConvert(floatSlot *unaryOpSlot, f *Frame, o *Object) (*Float, *BaseException) { + result, raised := floatSlot.Fn(f, o) + if raised != nil { + return nil, raised + } + if !result.isInstance(FloatType) { + exc := fmt.Sprintf("__float__ returned non-float (type %s)", result.typ.Name()) + return nil, f.RaiseType(TypeErrorType, exc) + } + return toFloatUnsafe(result), nil +} + func floatDivModOp(f *Frame, method string, v, w *Object, fun func(v, w float64) (float64, bool)) (*Object, *BaseException) { floatW, ok := floatCoerce(w) if !ok { @@ -350,6 +437,55 @@ func floatDivModOp(f *Frame, method string, v, w *Object, fun func(v, w float64) return NewFloat(x).ToObject(), nil } +func floatDivAndModOp(f *Frame, method string, v, w *Object, fun func(v, w float64) (float64, float64, bool)) (*Object, *BaseException) { + floatW, ok := floatCoerce(w) + if !ok 
{ + if math.IsInf(floatW, 0) { + return nil, f.RaiseType(OverflowErrorType, "long int too large to convert to float") + } + return NotImplemented, nil + } + q, m, ok := fun(toFloatUnsafe(v).Value(), floatW) + if !ok { + return nil, f.RaiseType(ZeroDivisionErrorType, "float division or modulo by zero") + } + return NewTuple2(NewFloat(q).ToObject(), NewFloat(m).ToObject()).ToObject(), nil +} + +func hashFloat(v float64) int { + if math.IsNaN(v) { + return 0 + } + + if math.IsInf(v, 0) { + if math.IsInf(v, 1) { + return 314159 + } + if math.IsInf(v, -1) { + return -271828 + } + return 0 + } + + _, fracPart := math.Modf(v) + if fracPart == 0.0 { + i := big.Int{} + big.NewFloat(v).Int(&i) + if numInIntRange(&i) { + return int(i.Int64()) + } + // TODO: hashBigInt() is not yet matched that of cpython or pypy. + return hashBigInt(&i) + } + + v, expo := math.Frexp(v) + v *= 2147483648.0 + hiPart := int(v) + v = (v - float64(hiPart)) * 2147483648.0 + x := int(hiPart + int(v) + (expo << 15)) + return x +} + func floatModFunc(v, w float64) (float64, bool) { if w == 0.0 { return 0, false @@ -366,3 +502,21 @@ func floatModFunc(v, w float64) (float64, bool) { } return x, true } + +func floatToString(f float64, p int) string { + s := unsignPositiveInf(strings.ToLower(strconv.FormatFloat(f, 'g', p, 64))) + fun := func(r rune) bool { + return !unicode.IsDigit(r) + } + if i := strings.IndexFunc(s, fun); i == -1 { + s += ".0" + } + return s +} + +func unsignPositiveInf(s string) string { + if s == "+inf" { + return "inf" + } + return s +} diff --git a/runtime/float_test.go b/runtime/float_test.go index 21d340f5..b5c37fb7 100644 --- a/runtime/float_test.go +++ b/runtime/float_test.go @@ -60,6 +60,16 @@ func TestFloatArithmeticOps(t *testing.T) { {Div, True.ToObject(), NewFloat(0).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "float division or modulo by zero")}, {Div, NewFloat(math.Inf(1)).ToObject(), NewFloat(0).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "float division or modulo by zero")}, {Div, NewFloat(1.0).ToObject(), NewLong(bigLongNumber).ToObject(), nil, mustCreateException(OverflowErrorType, "long int too large to convert to float")}, + {FloorDiv, NewFloat(12.5).ToObject(), NewFloat(4).ToObject(), NewFloat(3).ToObject(), nil}, + {FloorDiv, NewFloat(-12.5).ToObject(), NewInt(4).ToObject(), NewFloat(-4).ToObject(), nil}, + {FloorDiv, NewInt(25).ToObject(), NewFloat(5).ToObject(), NewFloat(5.0).ToObject(), nil}, + {FloorDiv, NewFloat(math.Inf(1)).ToObject(), NewFloat(math.Inf(1)).ToObject(), NewFloat(math.NaN()).ToObject(), nil}, + {FloorDiv, NewFloat(math.Inf(-1)).ToObject(), NewInt(-20).ToObject(), NewFloat(math.Inf(1)).ToObject(), nil}, + {FloorDiv, NewInt(1).ToObject(), NewFloat(math.Inf(1)).ToObject(), NewFloat(0).ToObject(), nil}, + {FloorDiv, newObject(ObjectType), NewFloat(1.1).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for //: 'object' and 'float'")}, + {FloorDiv, NewFloat(1.0).ToObject(), NewLong(bigLongNumber).ToObject(), nil, mustCreateException(OverflowErrorType, "long int too large to convert to float")}, + {FloorDiv, True.ToObject(), NewFloat(0).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "float division or modulo by zero")}, + {FloorDiv, NewFloat(math.Inf(1)).ToObject(), NewFloat(0).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "float division or modulo by zero")}, {Mod, NewFloat(50.5).ToObject(), NewInt(10).ToObject(), NewFloat(0.5).ToObject(), nil}, {Mod, NewFloat(50.5).ToObject(), 
NewFloat(-10).ToObject(), NewFloat(-9.5).ToObject(), nil}, {Mod, NewFloat(-20.2).ToObject(), NewFloat(40).ToObject(), NewFloat(19.8).ToObject(), nil}, @@ -101,6 +111,48 @@ func TestFloatArithmeticOps(t *testing.T) { } } +func TestFloatDivMod(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(12.5, 4.0), want: NewTuple2(NewFloat(3).ToObject(), NewFloat(0.5).ToObject()).ToObject()}, + {args: wrapArgs(-12.5, 4.0), want: NewTuple2(NewFloat(-4).ToObject(), NewFloat(3.5).ToObject()).ToObject()}, + {args: wrapArgs(25.0, 5.0), want: NewTuple2(NewFloat(5).ToObject(), NewFloat(0).ToObject()).ToObject()}, + {args: wrapArgs(-20.2, 40.0), want: NewTuple2(NewFloat(-1).ToObject(), NewFloat(19.8).ToObject()).ToObject()}, + {args: wrapArgs(math.Inf(1), math.Inf(1)), want: NewTuple2(NewFloat(math.NaN()).ToObject(), NewFloat(math.NaN()).ToObject()).ToObject()}, + {args: wrapArgs(math.Inf(1), math.Inf(-1)), want: NewTuple2(NewFloat(math.NaN()).ToObject(), NewFloat(math.NaN()).ToObject()).ToObject()}, + {args: wrapArgs(math.Inf(-1), -20.0), want: NewTuple2(NewFloat(math.Inf(1)).ToObject(), NewFloat(math.NaN()).ToObject()).ToObject()}, + {args: wrapArgs(1, math.Inf(1)), want: NewTuple2(NewFloat(0).ToObject(), NewFloat(1).ToObject()).ToObject()}, + {args: wrapArgs(newObject(ObjectType), 1.1), wantExc: mustCreateException(TypeErrorType, "unsupported operand type(s) for divmod(): 'object' and 'float'")}, + {args: wrapArgs(True.ToObject(), 0.0), wantExc: mustCreateException(ZeroDivisionErrorType, "float division or modulo by zero")}, + {args: wrapArgs(math.Inf(1), 0.0), wantExc: mustCreateException(ZeroDivisionErrorType, "float division or modulo by zero")}, + {args: wrapArgs(1.0, bigLongNumber), wantExc: mustCreateException(OverflowErrorType, "long int too large to convert to float")}, + } + for _, cas := range cases { + switch got, result := checkInvokeResult(wrapFuncForTest(DivMod), cas.args, cas.want, cas.wantExc); result { + case checkInvokeResultExceptionMismatch: + t.Errorf("float.__divmod__%v raised %v, want %v", cas.args, got, cas.wantExc) + case checkInvokeResultReturnValueMismatch: + // Handle NaN specially, since NaN != NaN. 
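The comparison below has to special-case NaN because IEEE 754 defines NaN as unequal to everything, including itself; plain tuple equality would therefore reject correct results such as divmod(inf, inf). A minimal standalone illustration of the point (hypothetical file name nan_compare_sketch.go, standard library only, not part of this patch):

// nan_compare_sketch.go — standalone illustration, not part of the patch.
package main

import (
	"fmt"
	"math"
)

// floatEqualNaN treats two NaNs as equal, unlike the == operator.
func floatEqualNaN(a, b float64) bool {
	if math.IsNaN(a) && math.IsNaN(b) {
		return true
	}
	return a == b
}

func main() {
	nan := math.NaN()
	fmt.Println(nan == nan)              // false: NaN never compares equal to itself
	fmt.Println(floatEqualNaN(nan, nan)) // true: the helper special-cases NaN
}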
+ if got == nil || cas.want == nil || !got.isInstance(TupleType) || !cas.want.isInstance(TupleType) || + !isNaNTupleFloat(got, cas.want) { + t.Errorf("float.__divmod__%v = %v, want %v", cas.args, got, cas.want) + } + } + } +} + +func isNaNTupleFloat(got, want *Object) bool { + if toTupleUnsafe(got).Len() != toTupleUnsafe(want).Len() { + return false + } + for i := 0; i < toTupleUnsafe(got).Len(); i++ { + if math.IsNaN(toFloatUnsafe(toTupleUnsafe(got).GetItem(i)).Value()) && + math.IsNaN(toFloatUnsafe(toTupleUnsafe(want).GetItem(i)).Value()) { + return true + } + } + return false +} + func TestFloatCompare(t *testing.T) { cases := []invokeTestCase{ {args: wrapArgs(1.0, 1.0), want: compareAllResultEq}, @@ -151,6 +203,22 @@ func TestFloatLong(t *testing.T) { } } +func TestFloatHash(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(NewFloat(0.0)), want: NewInt(0).ToObject()}, + {args: wrapArgs(NewFloat(3.14)), want: NewInt(3146129223).ToObject()}, + {args: wrapArgs(NewFloat(42.0)), want: NewInt(42).ToObject()}, + {args: wrapArgs(NewFloat(42.125)), want: NewInt(1413677056).ToObject()}, + {args: wrapArgs(NewFloat(math.Inf(1))), want: NewInt(314159).ToObject()}, + {args: wrapArgs(NewFloat(math.Inf(-1))), want: NewInt(-271828).ToObject()}, + {args: wrapArgs(NewFloat(math.NaN())), want: NewInt(0).ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(floatHash), &cas); err != "" { + t.Error(err) + } + } +} func TestFloatIsTrue(t *testing.T) { cases := []invokeTestCase{ {args: wrapArgs(0.0), want: False.ToObject()}, @@ -167,6 +235,10 @@ func TestFloatIsTrue(t *testing.T) { func TestFloatNew(t *testing.T) { floatNew := mustNotRaise(GetAttr(NewRootFrame(), FloatType.ToObject(), NewStr("__new__"), nil)) strictEqType := newTestClassStrictEq("StrictEq", FloatType) + newStrictEq := func(v float64) *Object { + f := Float{Object: Object{typ: strictEqType}, value: v} + return f.ToObject() + } subType := newTestClass("SubType", []*Type{FloatType}, newStringDict(map[string]*Object{})) subTypeObject := (&Float{Object: Object{typ: subType}, value: 3.14}).ToObject() goodSlotType := newTestClass("GoodSlot", []*Type{ObjectType}, newStringDict(map[string]*Object{ @@ -203,8 +275,8 @@ func TestFloatNew(t *testing.T) { {args: wrapArgs(FloatType, newObject(goodSlotType)), want: NewFloat(3.14).ToObject()}, {args: wrapArgs(FloatType, newObject(badSlotType)), wantExc: mustCreateException(TypeErrorType, "__float__ returned non-float (type object)")}, {args: wrapArgs(FloatType, newObject(slotSubTypeType)), want: subTypeObject}, - {args: wrapArgs(strictEqType, 3.14), want: (&Float{Object{typ: strictEqType}, 3.14}).ToObject()}, - {args: wrapArgs(strictEqType, newObject(goodSlotType)), want: (&Float{Object{typ: strictEqType}, 3.14}).ToObject()}, + {args: wrapArgs(strictEqType, 3.14), want: newStrictEq(3.14)}, + {args: wrapArgs(strictEqType, newObject(goodSlotType)), want: newStrictEq(3.14)}, {args: wrapArgs(strictEqType, newObject(badSlotType)), wantExc: mustCreateException(TypeErrorType, "__float__ returned non-float (type object)")}, {args: wrapArgs(), wantExc: mustCreateException(TypeErrorType, "'__new__' requires 1 arguments")}, {args: wrapArgs(IntType), wantExc: mustCreateException(TypeErrorType, "float.__new__(int): int is not a subtype of float")}, @@ -226,18 +298,43 @@ func TestFloatNew(t *testing.T) { } } -func TestFloatStrRepr(t *testing.T) { +func TestFloatRepr(t *testing.T) { cases := []invokeTestCase{ - {args: wrapArgs(0.0), want: NewStr("0").ToObject()}, + {args: 
wrapArgs(0.0), want: NewStr("0.0").ToObject()}, {args: wrapArgs(0.1), want: NewStr("0.1").ToObject()}, {args: wrapArgs(-303.5), want: NewStr("-303.5").ToObject()}, - {args: wrapArgs(231095835.0), want: NewStr("2.31095835e+08").ToObject()}, + {args: wrapArgs(231095835.0), want: NewStr("231095835.0").ToObject()}, + {args: wrapArgs(1e+6), want: NewStr("1000000.0").ToObject()}, + {args: wrapArgs(1e+15), want: NewStr("1000000000000000.0").ToObject()}, + {args: wrapArgs(1e+16), want: NewStr("1e+16").ToObject()}, + {args: wrapArgs(1E16), want: NewStr("1e+16").ToObject()}, + {args: wrapArgs(1e-6), want: NewStr("1e-06").ToObject()}, + {args: wrapArgs(math.Inf(1)), want: NewStr("inf").ToObject()}, + {args: wrapArgs(math.Inf(-1)), want: NewStr("-inf").ToObject()}, + {args: wrapArgs(math.NaN()), want: NewStr("nan").ToObject()}, } for _, cas := range cases { - if err := runInvokeTestCase(wrapFuncForTest(ToStr), &cas); err != "" { + if err := runInvokeTestCase(wrapFuncForTest(Repr), &cas); err != "" { t.Error(err) } - if err := runInvokeTestCase(wrapFuncForTest(Repr), &cas); err != "" { + } +} + +func TestFloatStr(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(1.0), want: NewStr("1.0").ToObject()}, + {args: wrapArgs(-847.373), want: NewStr("-847.373").ToObject()}, + {args: wrapArgs(0.123456789123456789), want: NewStr("0.123456789123").ToObject()}, + {args: wrapArgs(1e+11), want: NewStr("100000000000.0").ToObject()}, + {args: wrapArgs(1e+12), want: NewStr("1e+12").ToObject()}, + {args: wrapArgs(1e-4), want: NewStr("0.0001").ToObject()}, + {args: wrapArgs(1e-5), want: NewStr("1e-05").ToObject()}, + {args: wrapArgs(math.Inf(1)), want: NewStr("inf").ToObject()}, + {args: wrapArgs(math.Inf(-1)), want: NewStr("-inf").ToObject()}, + {args: wrapArgs(math.NaN()), want: NewStr("nan").ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(floatStr), &cas); err != "" { t.Error(err) } } diff --git a/runtime/frame.go b/runtime/frame.go index a098d9bb..d00cf4b0 100644 --- a/runtime/frame.go +++ b/runtime/frame.go @@ -70,7 +70,7 @@ func (f *Frame) release() { // TODO: Track cache depth and release memory. f.frameCache, f.back = f, f.frameCache // Clear pointers early. - f.dict = nil + f.setDict(nil) f.globals = nil f.code = nil } else if f.back != nil { @@ -270,6 +270,14 @@ func (f *Frame) FreeArgs(args Args) { // FrameType is the object representing the Python 'frame' type. 
var FrameType = newBasisType("frame", reflect.TypeOf(Frame{}), toFrameUnsafe, ObjectType) +func frameExcClear(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "__exc_clear__", args, FrameType); raised != nil { + return nil, raised + } + toFrameUnsafe(args[0]).RestoreExc(nil, nil) + return None, nil +} + func frameExcInfo(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { if raised := checkMethodVarArgs(f, "__exc_info__", args, FrameType); raised != nil { return nil, raised @@ -282,10 +290,11 @@ func frameExcInfo(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) if tb != nil { tbObj = tb.ToObject() } - return NewTuple(excObj, tbObj).ToObject(), nil + return NewTuple2(excObj, tbObj).ToObject(), nil } func initFrameType(dict map[string]*Object) { FrameType.flags &= ^(typeFlagInstantiable | typeFlagBasetype) + dict["__exc_clear__"] = newBuiltinFunction("__exc_clear__", frameExcClear).ToObject() dict["__exc_info__"] = newBuiltinFunction("__exc_info__", frameExcInfo).ToObject() } diff --git a/runtime/frame_test.go b/runtime/frame_test.go index c8bcb12d..80d2a42d 100644 --- a/runtime/frame_test.go +++ b/runtime/frame_test.go @@ -288,7 +288,7 @@ type invokeTestCase struct { func runInvokeTestCase(callable *Object, cas *invokeTestCase) string { f := NewRootFrame() - name := mustNotRaise(GetAttr(f, callable, NewStr("__name__"), NewStr("").ToObject())) + name := mustNotRaise(GetAttr(f, callable, internedName, NewStr("").ToObject())) if !name.isInstance(StrType) { return fmt.Sprintf("%v.__name__ is not a string", callable) } diff --git a/runtime/function.go b/runtime/function.go index f3ed41ba..1a6c8962 100644 --- a/runtime/function.go +++ b/runtime/function.go @@ -121,8 +121,14 @@ func functionCall(f *Frame, callable *Object, args Args, kwargs KWArgs) (*Object return code.Eval(f, fun.globals, args, kwargs) } -func functionGet(_ *Frame, desc, instance *Object, owner *Type) (*Object, *BaseException) { - return NewMethod(toFunctionUnsafe(desc), instance, owner).ToObject(), nil +func functionGet(f *Frame, desc, instance *Object, owner *Type) (*Object, *BaseException) { + args := f.MakeArgs(3) + args[0] = desc + args[1] = instance + args[2] = owner.ToObject() + ret, raised := MethodType.Call(f, args, nil) + f.FreeArgs(args) + return ret, raised } func functionRepr(_ *Frame, o *Object) (*Object, *BaseException) { @@ -201,7 +207,13 @@ func classMethodGet(f *Frame, desc, _ *Object, owner *Type) (*Object, *BaseExcep if m.callable == nil { return nil, f.RaiseType(RuntimeErrorType, "uninitialized classmethod object") } - return NewMethod(toFunctionUnsafe(m.callable), owner.ToObject(), owner).ToObject(), nil + args := f.MakeArgs(3) + args[0] = m.callable + args[1] = owner.ToObject() + args[2] = args[1] + ret, raised := MethodType.Call(f, args, nil) + f.FreeArgs(args) + return ret, raised } func classMethodInit(f *Frame, o *Object, args Args, _ KWArgs) (*Object, *BaseException) { diff --git a/runtime/function_test.go b/runtime/function_test.go index 57cb06e3..cf018bd5 100644 --- a/runtime/function_test.go +++ b/runtime/function_test.go @@ -49,7 +49,7 @@ func TestFunctionGet(t *testing.T) { func TestFunctionName(t *testing.T) { fun := newBuiltinFunction("TestFunctionName", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { foo := newBuiltinFunction("foo", func(*Frame, Args, KWArgs) (*Object, *BaseException) { return None, nil }) - return GetAttr(f, foo.ToObject(), NewStr("__name__"), nil) + return GetAttr(f, foo.ToObject(), 
internedName, nil) }).ToObject() if err := runInvokeTestCase(fun, &invokeTestCase{want: NewStr("foo").ToObject()}); err != "" { t.Error(err) @@ -126,12 +126,27 @@ func TestStaticMethodInit(t *testing.T) { } func TestClassMethodGet(t *testing.T) { + fun := wrapFuncForTest(func(f *Frame, meth *classMethod, args ...*Object) (*Object, *BaseException) { + get, raised := GetAttr(f, meth.ToObject(), NewStr("__get__"), nil) + if raised != nil { + return nil, raised + } + callable, raised := get.Call(f, args, nil) + if raised != nil { + return nil, raised + } + return callable.Call(f, nil, nil) + }) + echoFunc := wrapFuncForTest(func(f *Frame, args ...*Object) *Tuple { + return NewTuple(args...) + }) cases := []invokeTestCase{ - // {args: wrapArgs(newClassMethod(NewStr("abc").ToObject()), 123, IntType), want: NewStr("abc").ToObject()}, + {args: wrapArgs(newClassMethod(echoFunc), ObjectType, ObjectType), want: NewTuple(ObjectType.ToObject()).ToObject()}, + {args: wrapArgs(newClassMethod(NewStr("abc").ToObject()), 123, IntType), wantExc: mustCreateException(TypeErrorType, "first argument must be callable")}, {args: wrapArgs(newClassMethod(nil), 123, IntType), wantExc: mustCreateException(RuntimeErrorType, "uninitialized classmethod object")}, } for _, cas := range cases { - if err := runInvokeMethodTestCase(ClassMethodType, "__get__", &cas); err != "" { + if err := runInvokeTestCase(fun, &cas); err != "" { t.Error(err) } } diff --git a/runtime/int.go b/runtime/int.go index 87546ef3..bf9ad654 100644 --- a/runtime/int.go +++ b/runtime/int.go @@ -90,6 +90,10 @@ func intDiv(f *Frame, v, w *Object) (*Object, *BaseException) { return intDivModOp(f, "__div__", v, w, intCheckedDiv, longDiv) } +func intDivMod(f *Frame, v, w *Object) (*Object, *BaseException) { + return intDivAndModOp(f, "__divmod__", v, w, intCheckedDivMod, longDivAndMod) +} + func intEq(f *Frame, v, w *Object) (*Object, *BaseException) { return intCompare(compareOpEq, toIntUnsafe(v), w), nil } @@ -102,7 +106,7 @@ func intGetNewArgs(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkMethodArgs(f, "__getnewargs__", args, IntType); raised != nil { return nil, raised } - return NewTuple(args[0]).ToObject(), nil + return NewTuple1(args[0]).ToObject(), nil } func intGT(f *Frame, v, w *Object) (*Object, *BaseException) { @@ -118,6 +122,11 @@ func intHash(f *Frame, o *Object) (*Object, *BaseException) { return o, nil } +func intHex(f *Frame, o *Object) (*Object, *BaseException) { + val := numberToBase("0x", 16, o) + return NewStr(val).ToObject(), nil +} + func intIndex(f *Frame, o *Object) (*Object, *BaseException) { return o, nil } @@ -244,6 +253,14 @@ func intNonZero(f *Frame, o *Object) (*Object, *BaseException) { return GetBool(toIntUnsafe(o).Value() != 0).ToObject(), nil } +func intOct(f *Frame, o *Object) (*Object, *BaseException) { + val := numberToBase("0", 8, o) + if val == "00" { + val = "0" + } + return NewStr(val).ToObject(), nil +} + func intOr(f *Frame, v, w *Object) (*Object, *BaseException) { if !w.isInstance(IntType) { return NotImplemented, nil @@ -251,6 +268,10 @@ func intOr(f *Frame, v, w *Object) (*Object, *BaseException) { return NewInt(toIntUnsafe(v).Value() | toIntUnsafe(w).Value()).ToObject(), nil } +func intPos(f *Frame, o *Object) (*Object, *BaseException) { + return o, nil +} + func intPow(f *Frame, v, w *Object) (*Object, *BaseException) { if w.isInstance(IntType) { // First try to use the faster floating point arithmetic @@ -300,6 +321,14 @@ func intRDiv(f *Frame, v, w *Object) (*Object, 
*BaseException) { }) } +func intRDivMod(f *Frame, v, w *Object) (*Object, *BaseException) { + return intDivAndModOp(f, "__rdivmod__", v, w, func(v, w int) (int, int, divModResult) { + return intCheckedDivMod(w, v) + }, func(z, m, x, y *big.Int) { + longDivAndMod(z, m, y, x) + }) +} + func intRepr(f *Frame, o *Object) (*Object, *BaseException) { return NewStr(strconv.FormatInt(int64(toIntUnsafe(o).Value()), 10)).ToObject(), nil } @@ -353,11 +382,14 @@ func initIntType(dict map[string]*Object) { IntType.slots.Add = &binaryOpSlot{intAdd} IntType.slots.And = &binaryOpSlot{intAnd} IntType.slots.Div = &binaryOpSlot{intDiv} + IntType.slots.DivMod = &binaryOpSlot{intDivMod} IntType.slots.Eq = &binaryOpSlot{intEq} + IntType.slots.FloorDiv = &binaryOpSlot{intDiv} IntType.slots.GE = &binaryOpSlot{intGE} IntType.slots.GT = &binaryOpSlot{intGT} IntType.slots.Float = &unaryOpSlot{intFloat} IntType.slots.Hash = &unaryOpSlot{intHash} + IntType.slots.Hex = &unaryOpSlot{intHex} IntType.slots.Index = &unaryOpSlot{intIndex} IntType.slots.Int = &unaryOpSlot{intInt} IntType.slots.Invert = &unaryOpSlot{intInvert} @@ -372,12 +404,16 @@ func initIntType(dict map[string]*Object) { IntType.slots.Neg = &unaryOpSlot{intNeg} IntType.slots.New = &newSlot{intNew} IntType.slots.NonZero = &unaryOpSlot{intNonZero} + IntType.slots.Oct = &unaryOpSlot{intOct} IntType.slots.Or = &binaryOpSlot{intOr} + IntType.slots.Pos = &unaryOpSlot{intPos} IntType.slots.Pow = &binaryOpSlot{intPow} IntType.slots.RAdd = &binaryOpSlot{intRAdd} IntType.slots.RAnd = &binaryOpSlot{intAnd} IntType.slots.RDiv = &binaryOpSlot{intRDiv} + IntType.slots.RDivMod = &binaryOpSlot{intRDivMod} IntType.slots.Repr = &unaryOpSlot{intRepr} + IntType.slots.RFloorDiv = &binaryOpSlot{intRDiv} IntType.slots.RMod = &binaryOpSlot{intRMod} IntType.slots.RMul = &binaryOpSlot{intRMul} IntType.slots.ROr = &binaryOpSlot{intOr} @@ -516,6 +552,20 @@ func intDivModOp(f *Frame, method string, v, w *Object, fun func(v, w int) (int, return NewInt(x).ToObject(), nil } +func intDivAndModOp(f *Frame, method string, v, w *Object, fun func(v, w int) (int, int, divModResult), bigFun func(z, m, x, y *big.Int)) (*Object, *BaseException) { + if !w.isInstance(IntType) { + return NotImplemented, nil + } + q, m, r := fun(toIntUnsafe(v).Value(), toIntUnsafe(w).Value()) + switch r { + case divModOverflow: + return longCallBinaryTuple(bigFun, intToLong(toIntUnsafe(v)), intToLong(toIntUnsafe(w))), nil + case divModZeroDivision: + return nil, f.RaiseType(ZeroDivisionErrorType, "integer division or modulo by zero") + } + return NewTuple2(NewInt(q).ToObject(), NewInt(m).ToObject()).ToObject(), nil +} + func intShiftOp(f *Frame, v, w *Object, fun func(int, int) (int, int, bool)) (*Object, *BaseException) { if !w.isInstance(IntType) { return NotImplemented, nil diff --git a/runtime/int_test.go b/runtime/int_test.go index 33e35f6e..b2714f8c 100644 --- a/runtime/int_test.go +++ b/runtime/int_test.go @@ -39,6 +39,19 @@ func TestIntBinaryOps(t *testing.T) { {Div, NewList().ToObject(), NewInt(21).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for /: 'list' and 'int'")}, {Div, NewInt(1).ToObject(), NewInt(0).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "integer division or modulo by zero")}, {Div, NewInt(MinInt).ToObject(), NewInt(-1).ToObject(), NewLong(new(big.Int).Neg(minIntBig)).ToObject(), nil}, + {DivMod, NewInt(7).ToObject(), NewInt(3).ToObject(), NewTuple2(NewInt(2).ToObject(), NewInt(1).ToObject()).ToObject(), nil}, + {DivMod, NewInt(3).ToObject(), 
NewInt(-7).ToObject(), NewTuple2(NewInt(-1).ToObject(), NewInt(-4).ToObject()).ToObject(), nil}, + {DivMod, NewInt(MaxInt).ToObject(), NewInt(MinInt).ToObject(), NewTuple2(NewInt(-1).ToObject(), NewInt(-1).ToObject()).ToObject(), nil}, + {DivMod, NewInt(MinInt).ToObject(), NewInt(MaxInt).ToObject(), NewTuple2(NewInt(-2).ToObject(), NewInt(MaxInt-1).ToObject()).ToObject(), nil}, + {DivMod, NewList().ToObject(), NewInt(21).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for divmod(): 'list' and 'int'")}, + {DivMod, NewInt(1).ToObject(), NewInt(0).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "integer division or modulo by zero")}, + {DivMod, NewInt(MinInt).ToObject(), NewInt(-1).ToObject(), NewTuple2(NewLong(new(big.Int).Neg(minIntBig)).ToObject(), NewLong(big.NewInt(0)).ToObject()).ToObject(), nil}, + {FloorDiv, NewInt(7).ToObject(), NewInt(3).ToObject(), NewInt(2).ToObject(), nil}, + {FloorDiv, NewInt(MaxInt).ToObject(), NewInt(MinInt).ToObject(), NewInt(-1).ToObject(), nil}, + {FloorDiv, NewInt(MinInt).ToObject(), NewInt(MaxInt).ToObject(), NewInt(-2).ToObject(), nil}, + {FloorDiv, NewList().ToObject(), NewInt(21).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for //: 'list' and 'int'")}, + {FloorDiv, NewInt(1).ToObject(), NewInt(0).ToObject(), nil, mustCreateException(ZeroDivisionErrorType, "integer division or modulo by zero")}, + {FloorDiv, NewInt(MinInt).ToObject(), NewInt(-1).ToObject(), NewLong(new(big.Int).Neg(minIntBig)).ToObject(), nil}, {LShift, NewInt(2).ToObject(), NewInt(4).ToObject(), NewInt(32).ToObject(), nil}, {LShift, NewInt(-12).ToObject(), NewInt(10).ToObject(), NewInt(-12288).ToObject(), nil}, {LShift, NewInt(10).ToObject(), NewInt(100).ToObject(), NewLong(new(big.Int).Lsh(big.NewInt(10), 100)).ToObject(), nil}, diff --git a/runtime/list.go b/runtime/list.go index 15046a63..41e4c442 100644 --- a/runtime/list.go +++ b/runtime/list.go @@ -61,20 +61,61 @@ func (l *List) Append(o *Object) { l.mutex.Unlock() } +// DelItem removes the index'th element of l. +func (l *List) DelItem(f *Frame, index int) *BaseException { + l.mutex.Lock() + numElems := len(l.elems) + i, raised := seqCheckedIndex(f, numElems, index) + if raised == nil { + copy(l.elems[i:numElems-1], l.elems[i+1:numElems]) + l.elems = l.elems[:numElems-1] + } + l.mutex.Unlock() + return raised +} + +// DelSlice removes the slice of l specified by s. +func (l *List) DelSlice(f *Frame, s *Slice) *BaseException { + l.mutex.Lock() + numListElems := len(l.elems) + start, stop, step, numSliceElems, raised := s.calcSlice(f, numListElems) + if raised == nil { + if step == 1 { + copy(l.elems[start:numListElems-numSliceElems], l.elems[stop:numListElems]) + } else { + j := 0 + for i := start; i != stop; i += step { + next := i + step + if next > numListElems { + next = numListElems + } + dest := l.elems[i-j : next-j-1] + src := l.elems[i+1 : next] + copy(dest, src) + j++ + } + } + l.elems = l.elems[:numListElems-numSliceElems] + } + l.mutex.Unlock() + return raised +} + // SetItem sets the index'th element of l to value. func (l *List) SetItem(f *Frame, index int, value *Object) *BaseException { - l.mutex.RLock() + l.mutex.Lock() i, raised := seqCheckedIndex(f, len(l.elems), index) if raised == nil { l.elems[i] = value } - l.mutex.RUnlock() + l.mutex.Unlock() return raised } // SetSlice replaces the slice of l specified by s with the contents of value // (an iterable). 
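For orientation, DelSlice above and SetSlice below implement the semantics of del l[start:stop:step] and l[start:stop:step] = iterable. The standalone sketch that follows is illustrative only (hypothetical file name delslice_sketch.go, standard library only, not part of this patch); it shows the compaction an extended-slice delete must perform, restricted to a positive step, and is deliberately simpler than DelSlice, which shifts elements in place with copy.

// delslice_sketch.go — standalone illustration, not part of the patch.
package main

import "fmt"

// deleteSlice removes elems[start], elems[start+step], ... below stop,
// assuming 0 <= start <= stop <= len(elems) and step > 0. It compacts
// the kept elements into the same backing array, like an in-place filter.
func deleteSlice(elems []int, start, stop, step int) []int {
	kept := elems[:start]
	for i := start; i < stop; i++ {
		if (i-start)%step != 0 {
			kept = append(kept, elems[i])
		}
	}
	return append(kept, elems[stop:]...)
}

func main() {
	l := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	// Python equivalent: del l[1:8:3] removes indices 1, 4 and 7.
	fmt.Println(deleteSlice(l, 1, 8, 3)) // [0 2 3 5 6 8 9]
}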
func (l *List) SetSlice(f *Frame, s *Slice, value *Object) *BaseException { + l.mutex.Lock() numListElems := len(l.elems) start, stop, step, numSliceElems, raised := s.calcSlice(f, numListElems) if raised == nil { @@ -98,6 +139,7 @@ func (l *List) SetSlice(f *Frame, s *Slice, value *Object) *BaseException { return nil }) } + l.mutex.Unlock() return raised } @@ -164,13 +206,50 @@ func listAppend(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { } func listCount(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { - argc := len(args) - if argc != 2 { - return nil, f.RaiseType(TypeErrorType, fmt.Sprintf("count() takes exactly one argument (%d given)", argc)) + if raised := checkMethodArgs(f, "count", args, ListType, ObjectType); raised != nil { + return nil, raised } return seqCount(f, args[0], args[1]) } +func listDelItem(f *Frame, o *Object, key *Object) *BaseException { + l := toListUnsafe(o) + if key.isInstance(SliceType) { + return l.DelSlice(f, toSliceUnsafe(key)) + } + if key.typ.slots.Index == nil { + format := "list indices must be integers, not %s" + return f.RaiseType(TypeErrorType, fmt.Sprintf(format, key.Type().Name())) + } + index, raised := IndexInt(f, key) + if raised != nil { + return raised + } + return l.DelItem(f, index) +} + +func listRemove(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "remove", args, ListType, ObjectType); raised != nil { + return nil, raised + } + value := args[1] + l := toListUnsafe(args[0]) + l.mutex.Lock() + index, raised := seqFindElem(f, l.elems, value) + if raised == nil { + if index != -1 { + l.elems = append(l.elems[:index], l.elems[index+1:]...) + } else { + raised = f.RaiseType(ValueErrorType, "list.remove(x): x not in list") + } + } + l.mutex.Unlock() + if raised != nil { + return nil, raised + } + return None, nil +} + func listExtend(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { argc := len(args) if argc != 2 { @@ -307,6 +386,50 @@ func listNE(f *Frame, v, w *Object) (*Object, *BaseException) { return listCompare(f, toListUnsafe(v), w, NE) } +func listIndex(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + expectedTypes := []*Type{ListType, ObjectType, ObjectType, ObjectType} + argc := len(args) + var raised *BaseException + if argc == 2 || argc == 3 { + expectedTypes = expectedTypes[:argc] + } + if raised = checkMethodArgs(f, "index", args, expectedTypes...); raised != nil { + return nil, raised + } + l := toListUnsafe(args[0]) + l.mutex.RLock() + numElems := len(l.elems) + start, stop := 0, numElems + if argc > 2 { + start, raised = IndexInt(f, args[2]) + if raised != nil { + l.mutex.RUnlock() + return nil, raised + } + } + if argc > 3 { + stop, raised = IndexInt(f, args[3]) + if raised != nil { + l.mutex.RUnlock() + return nil, raised + } + } + start, stop = adjustIndex(start, stop, numElems) + value := args[1] + index := -1 + if start < numElems && start < stop { + index, raised = seqFindElem(f, l.elems[start:stop], value) + } + l.mutex.RUnlock() + if raised != nil { + return nil, raised + } + if index == -1 { + return nil, f.RaiseType(ValueErrorType, fmt.Sprintf("%v is not in list", value)) + } + return NewInt(index + start).ToObject(), nil +} + func listPop(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { argc := len(args) expectedTypes := []*Type{ListType, ObjectType} @@ -374,7 +497,7 @@ func listReverse(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { func listSetItem(f *Frame, o, key, value *Object) 
*BaseException { l := toListUnsafe(o) - if key.typ.slots.Int != nil { + if key.typ.slots.Index != nil { i, raised := IndexInt(f, key) if raised != nil { return raised @@ -401,12 +524,15 @@ func initListType(dict map[string]*Object) { dict["append"] = newBuiltinFunction("append", listAppend).ToObject() dict["count"] = newBuiltinFunction("count", listCount).ToObject() dict["extend"] = newBuiltinFunction("extend", listExtend).ToObject() + dict["index"] = newBuiltinFunction("index", listIndex).ToObject() dict["insert"] = newBuiltinFunction("insert", listInsert).ToObject() dict["pop"] = newBuiltinFunction("pop", listPop).ToObject() + dict["remove"] = newBuiltinFunction("remove", listRemove).ToObject() dict["reverse"] = newBuiltinFunction("reverse", listReverse).ToObject() dict["sort"] = newBuiltinFunction("sort", listSort).ToObject() ListType.slots.Add = &binaryOpSlot{listAdd} ListType.slots.Contains = &binaryOpSlot{listContains} + ListType.slots.DelItem = &delItemSlot{listDelItem} ListType.slots.Eq = &binaryOpSlot{listEq} ListType.slots.GE = &binaryOpSlot{listGE} ListType.slots.GetItem = &binaryOpSlot{listGetItem} diff --git a/runtime/list_test.go b/runtime/list_test.go index e0525697..3b29a2c4 100644 --- a/runtime/list_test.go +++ b/runtime/list_test.go @@ -83,7 +83,7 @@ func TestListCount(t *testing.T) { cases := []invokeTestCase{ {args: wrapArgs(NewList(), NewInt(1)), want: NewInt(0).ToObject()}, {args: wrapArgs(NewList(None, None, None), None), want: NewInt(3).ToObject()}, - {args: wrapArgs(newTestList()), wantExc: mustCreateException(TypeErrorType, "count() takes exactly one argument (1 given)")}, + {args: wrapArgs(newTestList()), wantExc: mustCreateException(TypeErrorType, "'count' of 'list' requires 2 arguments")}, } for _, cas := range cases { if err := runInvokeMethodTestCase(ListType, "count", &cas); err != "" { @@ -92,6 +92,100 @@ func TestListCount(t *testing.T) { } } +func TestListDelItem(t *testing.T) { + badIndexType := newTestClass("badIndex", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__index__": newBuiltinFunction("__index__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return nil, f.RaiseType(ValueErrorType, "wut") + }).ToObject(), + })) + delItem := mustNotRaise(GetAttr(NewRootFrame(), ListType.ToObject(), NewStr("__delitem__"), nil)) + fun := wrapFuncForTest(func(f *Frame, l *List, key *Object) (*Object, *BaseException) { + _, raised := delItem.Call(f, wrapArgs(l, key), nil) + if raised != nil { + return nil, raised + } + return l.ToObject(), nil + }) + cases := []invokeTestCase{ + {args: wrapArgs(newTestRange(3), 0), want: newTestList(1, 2).ToObject()}, + {args: wrapArgs(newTestRange(3), 2), want: newTestList(0, 1).ToObject()}, + {args: wrapArgs(NewList(), 101), wantExc: mustCreateException(IndexErrorType, "index out of range")}, + {args: wrapArgs(NewList(), newTestSlice(50, 100)), want: NewList().ToObject()}, + {args: wrapArgs(newTestList(1, 2, 3, 4, 5), newTestSlice(1, 3, None)), want: newTestList(1, 4, 5).ToObject()}, + {args: wrapArgs(newTestList(1, 2, 3, 4, 5), newTestSlice(1, None, 2)), want: newTestList(1, 3, 5).ToObject()}, + {args: wrapArgs(newTestList(1, 2, 3, 4, 5), newTestSlice(big.NewInt(1), None, 2)), want: newTestList(1, 3, 5).ToObject()}, + {args: wrapArgs(newTestList(1, 2, 3, 4, 5), newTestSlice(1, big.NewInt(5), 2)), want: newTestList(1, 3, 5).ToObject()}, + {args: wrapArgs(newTestList(1, 2, 3, 4, 5), newTestSlice(1, None, big.NewInt(2))), want: newTestList(1, 3, 5).ToObject()}, + {args: wrapArgs(newTestList(1, 2, 3, 4, 
5), newTestSlice(1.0, 3, None)), wantExc: mustCreateException(TypeErrorType, errBadSliceIndex)}, + {args: wrapArgs(newTestList(1, 2, 3, 4, 5), newTestSlice(None, None, 4)), want: newTestList(2, 3, 4).ToObject()}, + {args: wrapArgs(newTestRange(10), newTestSlice(1, 8, 3)), want: newTestList(0, 2, 3, 5, 6, 8, 9).ToObject()}, + {args: wrapArgs(newTestList(1, 2, 3), newTestSlice(1, None, 0)), wantExc: mustCreateException(ValueErrorType, "slice step cannot be zero")}, + {args: wrapArgs(newTestList(true), None), wantExc: mustCreateException(TypeErrorType, "list indices must be integers, not NoneType")}, + {args: wrapArgs(newTestList(true), newObject(badIndexType)), wantExc: mustCreateException(ValueErrorType, "wut")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + +func TestListIndex(t *testing.T) { + intIndexType := newTestClass("IntIndex", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__index__": newBuiltinFunction("__index__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return NewInt(0).ToObject(), nil + }).ToObject(), + })) + cases := []invokeTestCase{ + // {args: wrapArgs(newTestList(), 1, "foo"), wantExc: mustCreateException(TypeErrorType, "slice indices must be integers or None or have an __index__ method")}, + {args: wrapArgs(newTestList(10, 20, 30), 20), want: NewInt(1).ToObject()}, + {args: wrapArgs(newTestList(10, 20, 30), 20, newObject(intIndexType)), want: NewInt(1).ToObject()}, + {args: wrapArgs(newTestList(0, "foo", "bar"), "foo"), want: NewInt(1).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 3), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 2.0, 2, 3, 4, 2, 1, "foo"), 3, 3), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 4), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 0, 4), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 0, 3), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, -2), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, -1), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 0, -1), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 0, -2), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 0, 999), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), "foo", 0, 999), wantExc: mustCreateException(ValueErrorType, "'foo' is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 999), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 5, 0), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + } + for _, cas := range cases { + if err := runInvokeMethodTestCase(ListType, "index", &cas); err != "" { + t.Error(err) + } + } +} + +func TestListRemove(t *testing.T) { + fun := newBuiltinFunction("TestListRemove", func(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + rem, raised := GetAttr(f, ListType.ToObject(), NewStr("remove"), nil) + if raised != nil { + return nil, raised + } + if _, raised := rem.Call(f, args, nil); raised != nil { + return nil, raised + } + return args[0], nil + }).ToObject() + cases := []invokeTestCase{ + {args: 
wrapArgs(newTestList(1, 2, 3), 2), want: newTestList(1, 3).ToObject()}, + {args: wrapArgs(newTestList(1, 2, 3, 2, 1), 2), want: newTestList(1, 3, 2, 1).ToObject()}, + {args: wrapArgs(NewList()), wantExc: mustCreateException(TypeErrorType, "'remove' of 'list' requires 2 arguments")}, + {args: wrapArgs(NewList(), 1), wantExc: mustCreateException(ValueErrorType, "list.remove(x): x not in list")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + func BenchmarkListContains(b *testing.B) { b.Run("false-3", func(b *testing.B) { t := newTestList("foo", 42, "bar").ToObject() diff --git a/runtime/long.go b/runtime/long.go index 83f10fa6..77dae9ad 100644 --- a/runtime/long.go +++ b/runtime/long.go @@ -112,6 +112,10 @@ func longDiv(z, x, y *big.Int) { longDivMod(x, y, z, &m) } +func longDivAndMod(z, m, x, y *big.Int) { + longDivMod(x, y, z, m) +} + func longEq(x, y *big.Int) bool { return x.Cmp(y) == 0 } @@ -124,7 +128,7 @@ func longGetNewArgs(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkMethodArgs(f, "__getnewargs__", args, LongType); raised != nil { return nil, raised } - return NewTuple(args[0]).ToObject(), nil + return NewTuple1(args[0]).ToObject(), nil } func longGT(x, y *big.Int) bool { @@ -144,6 +148,11 @@ func hashBigInt(x *big.Int) int { return hashString(x.Text(36)) } +func longHex(f *Frame, o *Object) (*Object, *BaseException) { + val := numberToBase("0x", 16, o) + "L" + return NewStr(val).ToObject(), nil +} + func longHash(f *Frame, o *Object) (*Object, *BaseException) { l := toLongUnsafe(o) l.hashOnce.Do(func() { @@ -290,10 +299,22 @@ func longNonZero(x *big.Int) bool { return x.Sign() != 0 } +func longOct(f *Frame, o *Object) (*Object, *BaseException) { + val := numberToBase("0", 8, o) + "L" + if val == "00L" { + val = "0L" + } + return NewStr(val).ToObject(), nil +} + func longOr(z, x, y *big.Int) { z.Or(x, y) } +func longPos(z, x *big.Int) { + z.Set(x) +} + func longRepr(f *Frame, o *Object) (*Object, *BaseException) { return NewStr(toLongUnsafe(o).value.Text(10) + "L").ToObject(), nil } @@ -320,11 +341,14 @@ func initLongType(dict map[string]*Object) { LongType.slots.Add = longBinaryOpSlot(longAdd) LongType.slots.And = longBinaryOpSlot(longAnd) LongType.slots.Div = longDivModOpSlot(longDiv) + LongType.slots.DivMod = longDivAndModOpSlot(longDivAndMod) LongType.slots.Eq = longBinaryBoolOpSlot(longEq) LongType.slots.Float = &unaryOpSlot{longFloat} + LongType.slots.FloorDiv = longDivModOpSlot(longDiv) LongType.slots.GE = longBinaryBoolOpSlot(longGE) LongType.slots.GT = longBinaryBoolOpSlot(longGT) LongType.slots.Hash = &unaryOpSlot{longHash} + LongType.slots.Hex = &unaryOpSlot{longHex} LongType.slots.Index = &unaryOpSlot{longIndex} LongType.slots.Int = &unaryOpSlot{longInt} LongType.slots.Invert = longUnaryOpSlot(longInvert) @@ -339,13 +363,17 @@ func initLongType(dict map[string]*Object) { LongType.slots.Neg = longUnaryOpSlot(longNeg) LongType.slots.New = &newSlot{longNew} LongType.slots.NonZero = longUnaryBoolOpSlot(longNonZero) + LongType.slots.Oct = &unaryOpSlot{longOct} LongType.slots.Or = longBinaryOpSlot(longOr) + LongType.slots.Pos = longUnaryOpSlot(longPos) // This operation can return a float, it must use binaryOpSlot directly. 
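The comment above refers to the CPython rule that an integer raised to a negative exponent yields a float (for example 2 ** -2 == 0.25), so long.__pow__ cannot go through the *big.Int-only slot wrappers used for the other operators. A standalone illustration of that rule (hypothetical helper, standard library only, not Grumpy's actual longPow):

// longpow_sketch.go — standalone illustration, not part of the patch.
package main

import (
	"fmt"
	"math"
	"math/big"
)

// pow returns x**y as a *big.Int for y >= 0 and as a float64 otherwise,
// mirroring the CPython behaviour described above.
func pow(x, y *big.Int) interface{} {
	if y.Sign() >= 0 {
		return new(big.Int).Exp(x, y, nil)
	}
	xf, _ := new(big.Float).SetInt(x).Float64()
	yf, _ := new(big.Float).SetInt(y).Float64()
	return math.Pow(xf, yf)
}

func main() {
	fmt.Println(pow(big.NewInt(2), big.NewInt(10))) // 1024
	fmt.Println(pow(big.NewInt(2), big.NewInt(-2))) // 0.25
}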
LongType.slots.Pow = &binaryOpSlot{longPow} LongType.slots.RAdd = longRBinaryOpSlot(longAdd) LongType.slots.RAnd = longRBinaryOpSlot(longAnd) LongType.slots.RDiv = longRDivModOpSlot(longDiv) + LongType.slots.RDivMod = longRDivAndModOpSlot(longDivAndMod) LongType.slots.Repr = &unaryOpSlot{longRepr} + LongType.slots.RFloorDiv = longRDivModOpSlot(longDiv) LongType.slots.RMod = longRDivModOpSlot(longMod) LongType.slots.RMul = longRBinaryOpSlot(longMul) LongType.slots.ROr = longRBinaryOpSlot(longOr) @@ -377,6 +405,13 @@ func longCallBinary(fun func(z, x, y *big.Int), v, w *Long) *Object { return l.ToObject() } +func longCallBinaryTuple(fun func(z, m, x, y *big.Int), v, w *Long) *Object { + l := Long{Object: Object{typ: LongType}} + ll := Long{Object: Object{typ: LongType}} + fun(&l.value, &ll.value, &v.value, &w.value) + return NewTuple2(l.ToObject(), ll.ToObject()).ToObject() +} + func longCallBinaryBool(fun func(x, y *big.Int) bool, v, w *Long) *Object { return GetBool(fun(&v.value, &w.value)).ToObject() } @@ -400,6 +435,13 @@ func longCallDivMod(fun func(z, x, y *big.Int), f *Frame, v, w *Long) (*Object, return longCallBinary(fun, v, w), nil } +func longCallDivAndMod(fun func(z, m, x, y *big.Int), f *Frame, v, w *Long) (*Object, *BaseException) { + if w.value.Sign() == 0 { + return nil, f.RaiseType(ZeroDivisionErrorType, "integer division or modulo by zero") + } + return longCallBinaryTuple(fun, v, w), nil +} + func longUnaryOpSlot(fun func(z, x *big.Int)) *unaryOpSlot { f := func(_ *Frame, v *Object) (*Object, *BaseException) { return longCallUnary(fun, toLongUnsafe(v)), nil @@ -462,6 +504,30 @@ func longRDivModOpSlot(fun func(z, x, y *big.Int)) *binaryOpSlot { return &binaryOpSlot{f} } +func longDivAndModOpSlot(fun func(z, m, x, y *big.Int)) *binaryOpSlot { + f := func(f *Frame, v, w *Object) (*Object, *BaseException) { + if w.isInstance(IntType) { + w = intToLong(toIntUnsafe(w)).ToObject() + } else if !w.isInstance(LongType) { + return NotImplemented, nil + } + return longCallDivAndMod(fun, f, toLongUnsafe(v), toLongUnsafe(w)) + } + return &binaryOpSlot{f} +} + +func longRDivAndModOpSlot(fun func(z, m, x, y *big.Int)) *binaryOpSlot { + f := func(f *Frame, v, w *Object) (*Object, *BaseException) { + if w.isInstance(IntType) { + w = intToLong(toIntUnsafe(w)).ToObject() + } else if !w.isInstance(LongType) { + return NotImplemented, nil + } + return longCallDivAndMod(fun, f, toLongUnsafe(w), toLongUnsafe(v)) + } + return &binaryOpSlot{f} +} + func longShiftOpSlot(fun func(z, x *big.Int, n uint)) *binaryOpSlot { f := func(f *Frame, v, w *Object) (*Object, *BaseException) { if w.isInstance(IntType) { diff --git a/runtime/long_test.go b/runtime/long_test.go index f031862a..fdef04b9 100644 --- a/runtime/long_test.go +++ b/runtime/long_test.go @@ -133,6 +133,20 @@ func TestLongBinaryOps(t *testing.T) { {Div, NewList().ToObject(), NewLong(big.NewInt(21)).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for /: 'list' and 'long'")}, {Div, 1, 0, nil, mustCreateException(ZeroDivisionErrorType, "integer division or modulo by zero")}, {Div, MinInt, -1, NewLong(new(big.Int).Neg(minIntBig)).ToObject(), nil}, + {DivMod, 7, 3, NewTuple2(NewLong(big.NewInt(2)).ToObject(), NewLong(big.NewInt(1)).ToObject()).ToObject(), nil}, + {DivMod, 3, -7, NewTuple2(NewLong(big.NewInt(-1)).ToObject(), NewLong(big.NewInt(-4)).ToObject()).ToObject(), nil}, + {DivMod, MaxInt, MinInt, NewTuple2(NewLong(big.NewInt(-1)).ToObject(), NewLong(big.NewInt(-1)).ToObject()).ToObject(), nil}, + {DivMod, 
MinInt, MaxInt, NewTuple2(NewLong(big.NewInt(-2)).ToObject(), NewLong(big.NewInt(MaxInt-1)).ToObject()).ToObject(), nil}, + {DivMod, MinInt, 1, NewTuple2(NewLong(big.NewInt(MinInt)).ToObject(), NewLong(big.NewInt(0)).ToObject()).ToObject(), nil}, + {DivMod, MinInt, -1, NewTuple2(NewLong(new(big.Int).Neg(minIntBig)).ToObject(), NewLong(big.NewInt(0)).ToObject()).ToObject(), nil}, + {DivMod, NewList().ToObject(), NewLong(big.NewInt(21)).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for divmod(): 'list' and 'long'")}, + {DivMod, 1, 0, nil, mustCreateException(ZeroDivisionErrorType, "integer division or modulo by zero")}, + {FloorDiv, 7, 3, NewLong(big.NewInt(2)).ToObject(), nil}, + {FloorDiv, MaxInt, MinInt, NewLong(big.NewInt(-1)).ToObject(), nil}, + {FloorDiv, MinInt, MaxInt, NewLong(big.NewInt(-2)).ToObject(), nil}, + {FloorDiv, NewList().ToObject(), NewLong(big.NewInt(21)).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for //: 'list' and 'long'")}, + {FloorDiv, 1, 0, nil, mustCreateException(ZeroDivisionErrorType, "integer division or modulo by zero")}, + {FloorDiv, MinInt, -1, NewLong(new(big.Int).Neg(minIntBig)).ToObject(), nil}, {LShift, 2, 4, NewLong(big.NewInt(32)).ToObject(), nil}, {LShift, 12, 10, NewLong(big.NewInt(12288)).ToObject(), nil}, {LShift, 10, 100, NewLong(new(big.Int).Lsh(big.NewInt(10), 100)).ToObject(), nil}, diff --git a/runtime/method.go b/runtime/method.go index f48c0e49..488aaef3 100644 --- a/runtime/method.go +++ b/runtime/method.go @@ -22,16 +22,10 @@ import ( // Method represents Python 'instancemethod' objects. type Method struct { Object - function *Function - self *Object - class *Type - name string `attr:"__name__"` -} - -// NewMethod returns a method wrapping the given function belonging to class. -// When self is None the method is unbound, otherwise it is bound to self. 
-func NewMethod(function *Function, self *Object, class *Type) *Method { - return &Method{Object{typ: MethodType}, function, self, class, function.Name()} + function *Object `attr:"im_func"` + self *Object `attr:"im_self"` + class *Object `attr:"im_class"` + name string `attr:"__name__"` } func toMethodUnsafe(o *Object) *Method { @@ -48,44 +42,136 @@ var MethodType = newBasisType("instancemethod", reflect.TypeOf(Method{}), toMeth func methodCall(f *Frame, callable *Object, args Args, kwargs KWArgs) (*Object, *BaseException) { m := toMethodUnsafe(callable) - var methodArgs []*Object argc := len(args) - if m.self == None { - if argc < 1 { - format := "unbound method %s() must be called with %s instance as first argument (got nothing instead)" - return nil, f.RaiseType(TypeErrorType, fmt.Sprintf(format, m.name, m.class.Name())) - } - if !args[0].isInstance(m.class) { - format := "unbound method %s() must be called with %s instance as first argument (got %s instance instead)" - return nil, f.RaiseType(TypeErrorType, fmt.Sprintf(format, m.name, m.class.Name(), args[0].typ.Name())) - } - methodArgs = args - } else { - methodArgs = make([]*Object, argc+1, argc+1) + if m.self != nil { + methodArgs := f.MakeArgs(argc + 1) methodArgs[0] = m.self copy(methodArgs[1:], args) + result, raised := m.function.Call(f, methodArgs, kwargs) + f.FreeArgs(methodArgs) + return result, raised + } + if argc < 1 { + className, raised := methodGetMemberName(f, m.class) + if raised != nil { + return nil, raised + } + format := "unbound method %s() must be called with %s " + + "instance as first argument (got nothing instead)" + return nil, f.RaiseType(TypeErrorType, fmt.Sprintf(format, m.name, className)) + } + // instancemethod.__new__ ensures that m.self and m.class are not both + // nil. Since m.self is nil, we know that m.class is not. + isInst, raised := IsInstance(f, args[0], m.class) + if raised != nil { + return nil, raised } - return m.function.Call(f, methodArgs, kwargs) + if !isInst { + className, raised := methodGetMemberName(f, m.class) + if raised != nil { + return nil, raised + } + format := "unbound method %s() must be called with %s " + + "instance as first argument (got %s instance instead)" + return nil, f.RaiseType(TypeErrorType, fmt.Sprintf(format, m.name, className, args[0].typ.Name())) + } + return m.function.Call(f, args, kwargs) +} + +func methodGet(f *Frame, desc, instance *Object, owner *Type) (*Object, *BaseException) { + m := toMethodUnsafe(desc) + if m.self != nil { + // Don't bind a method that's already bound. + return desc, nil + } + if m.class != nil { + subcls, raised := IsSubclass(f, owner.ToObject(), m.class) + if raised != nil { + return nil, raised + } + if !subcls { + // Don't bind if owner is not a subclass of m.class. 
+			return desc, nil
+		}
+	}
+	return (&Method{Object{typ: MethodType}, m.function, instance, owner.ToObject(), m.name}).ToObject(), nil
+}
+
+func methodNew(f *Frame, t *Type, args Args, _ KWArgs) (*Object, *BaseException) {
+	expectedTypes := []*Type{ObjectType, ObjectType, ObjectType}
+	argc := len(args)
+	if argc == 2 {
+		expectedTypes = expectedTypes[:2]
+	}
+	if raised := checkFunctionArgs(f, "__new__", args, expectedTypes...); raised != nil {
+		return nil, raised
+	}
+	function, self := args[0], args[1]
+	if self == None {
+		self = nil
+	}
+	var class *Object
+	if argc > 2 {
+		class = args[2]
+	} else if self == nil {
+		return nil, f.RaiseType(TypeErrorType, "unbound methods must have non-NULL im_class")
+	}
+	if function.Type().slots.Call == nil {
+		return nil, f.RaiseType(TypeErrorType, "first argument must be callable")
+	}
+	functionName, raised := methodGetMemberName(f, function)
+	if raised != nil {
+		return nil, raised
+	}
+	method := &Method{Object{typ: MethodType}, function, self, class, functionName}
+	return method.ToObject(), nil
 }
 
 func methodRepr(f *Frame, o *Object) (*Object, *BaseException) {
 	m := toMethodUnsafe(o)
 	s := ""
-	if m.self == None {
-		s = fmt.Sprintf("<unbound method %s.%s>", m.class.Name(), m.function.Name())
+	className, raised := methodGetMemberName(f, m.class)
+	if raised != nil {
+		return nil, raised
+	}
+	functionName, raised := methodGetMemberName(f, m.function)
+	if raised != nil {
+		return nil, raised
+	}
+	if m.self == nil {
+		s = fmt.Sprintf("<unbound method %s.%s>", className, functionName)
 	} else {
 		repr, raised := Repr(f, m.self)
 		if raised != nil {
 			return nil, raised
 		}
-		s = fmt.Sprintf("<bound method %s.%s of %s>", m.class.Name(), m.function.Name(), repr.Value())
+		s = fmt.Sprintf("<bound method %s.%s of %s>", className, functionName, repr.Value())
 	}
 	return NewStr(s).ToObject(), nil
 }
 
 func initMethodType(map[string]*Object) {
-	// TODO: Should be instantiable.
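With methodNew and methodGet in place, instancemethod objects can be constructed and bound from Go as well as from Python code. A minimal sketch, using the wrapArgs/wrapFuncForTest/mustNotRaise test helpers that appear later in this patch; the callable is just a placeholder:

    // Build an unbound method of int, then rebind it via the descriptor protocol,
    // mirroring TestMethodGet/TestMethodNew below.
    f := NewRootFrame()
    fn := wrapFuncForTest(func() {})
    unbound := mustNotRaise(MethodType.Call(f, wrapArgs(fn, None, IntType.ToObject()), nil))
    get := mustNotRaise(GetAttr(f, MethodType.ToObject(), NewStr("__get__"), nil))
    bound := mustNotRaise(get.Call(f, wrapArgs(unbound, 123, IntType), nil))
    _ = bound // im_self is now 123; an already-bound method passes through __get__ unchanged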
- MethodType.flags &= ^(typeFlagBasetype | typeFlagInstantiable) + MethodType.flags &= ^typeFlagBasetype MethodType.slots.Call = &callSlot{methodCall} + MethodType.slots.Get = &getSlot{methodGet} MethodType.slots.Repr = &unaryOpSlot{methodRepr} + MethodType.slots.New = &newSlot{methodNew} +} + +func methodGetMemberName(f *Frame, o *Object) (string, *BaseException) { + if o == nil { + return "?", nil + } + name, raised := GetAttr(f, o, internedName, None) + if raised != nil { + return "", raised + } + if !name.isInstance(BaseStringType) { + return "?", nil + } + nameStr, raised := ToStr(f, name) + if raised != nil { + return "", raised + } + return nameStr.Value(), nil } diff --git a/runtime/method_test.go b/runtime/method_test.go index 1efa7160..98f27e4b 100644 --- a/runtime/method_test.go +++ b/runtime/method_test.go @@ -21,19 +21,20 @@ import ( func TestMethodCall(t *testing.T) { foo := newBuiltinFunction("foo", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return NewTuple(args.makeCopy()...).ToObject(), nil - }) + }).ToObject() self := newObject(ObjectType) arg0 := NewInt(123).ToObject() arg1 := NewStr("abc").ToObject() cases := []invokeTestCase{ - {args: wrapArgs(NewMethod(foo, self, ObjectType)), want: NewTuple(self).ToObject()}, - {args: wrapArgs(NewMethod(foo, None, ObjectType), self), want: NewTuple(self).ToObject()}, - {args: wrapArgs(NewMethod(foo, self, ObjectType), arg0, arg1), want: NewTuple(self, arg0, arg1).ToObject()}, - {args: wrapArgs(NewMethod(foo, None, ObjectType), self, arg0, arg1), want: NewTuple(self, arg0, arg1).ToObject()}, + {args: wrapArgs(newTestMethod(foo, self, ObjectType.ToObject())), want: NewTuple(self).ToObject()}, + {args: wrapArgs(newTestMethod(foo, None, ObjectType.ToObject()), self), want: NewTuple(self).ToObject()}, + {args: wrapArgs(newTestMethod(foo, self, ObjectType.ToObject()), arg0, arg1), want: NewTuple(self, arg0, arg1).ToObject()}, + {args: wrapArgs(newTestMethod(foo, None, ObjectType.ToObject()), self, arg0, arg1), want: NewTuple(self, arg0, arg1).ToObject()}, {args: wrapArgs(), wantExc: mustCreateException(TypeErrorType, "unbound method __call__() must be called with instancemethod instance as first argument (got nothing instead)")}, {args: wrapArgs(newObject(ObjectType)), wantExc: mustCreateException(TypeErrorType, "unbound method __call__() must be called with instancemethod instance as first argument (got object instance instead)")}, - {args: wrapArgs(NewMethod(foo, None, IntType), newObject(ObjectType)), wantExc: mustCreateException(TypeErrorType, "unbound method foo() must be called with int instance as first argument (got object instance instead)")}, - {args: wrapArgs(NewMethod(foo, None, IntType)), wantExc: mustCreateException(TypeErrorType, "unbound method foo() must be called with int instance as first argument (got nothing instead)")}, + {args: wrapArgs(newTestMethod(foo, None, IntType.ToObject()), newObject(ObjectType)), wantExc: mustCreateException(TypeErrorType, "unbound method foo() must be called with int instance as first argument (got object instance instead)")}, + {args: wrapArgs(newTestMethod(foo, None, IntType.ToObject())), wantExc: mustCreateException(TypeErrorType, "unbound method foo() must be called with int instance as first argument (got nothing instead)")}, + {args: wrapArgs(newTestMethod(foo, None, None), None), wantExc: mustCreateException(TypeErrorType, "classinfo must be a type or tuple of types")}, } for _, cas := range cases { if err := runInvokeMethodTestCase(MethodType, "__call__", 
&cas); err != "" {
@@ -42,12 +43,59 @@ func TestMethodCall(t *testing.T) {
 	}
 }
 
+func TestMethodGet(t *testing.T) {
+	get := mustNotRaise(GetAttr(NewRootFrame(), MethodType.ToObject(), NewStr("__get__"), nil))
+	fun := wrapFuncForTest(func(f *Frame, args ...*Object) (*Object, *BaseException) {
+		o, raised := get.Call(f, args, nil)
+		if raised != nil {
+			return nil, raised
+		}
+		m := toMethodUnsafe(o)
+		self, class := m.self, m.class
+		if self == nil {
+			self = None
+		}
+		if class == nil {
+			class = None
+		}
+		return newTestTuple(m.function, self, class).ToObject(), nil
+	})
+	dummyFunc := wrapFuncForTest(func() {})
+	bound := mustNotRaise(MethodType.Call(NewRootFrame(), wrapArgs(dummyFunc, "foo"), nil))
+	unbound := newTestMethod(dummyFunc, None, IntType.ToObject())
+	cases := []invokeTestCase{
+		{args: wrapArgs(bound, "bar", StrType), want: newTestTuple(dummyFunc, "foo", None).ToObject()},
+		{args: wrapArgs(unbound, "bar", StrType), want: newTestTuple(dummyFunc, None, IntType).ToObject()},
+		{args: wrapArgs(unbound, 123, IntType), want: newTestTuple(dummyFunc, 123, IntType).ToObject()},
+		{args: wrapArgs(newTestMethod(dummyFunc, None, None), "bar", StrType), wantExc: mustCreateException(TypeErrorType, "classinfo must be a type or tuple of types")},
+	}
+	for _, cas := range cases {
+		if err := runInvokeTestCase(fun, &cas); err != "" {
+			t.Error(err)
+		}
+	}
+}
+
+func TestMethodNew(t *testing.T) {
+	cases := []invokeTestCase{
+		{wantExc: mustCreateException(TypeErrorType, "'__new__' requires 3 arguments")},
+		{args: Args{None, None, None}, wantExc: mustCreateException(TypeErrorType, "first argument must be callable")},
+		{args: Args{wrapFuncForTest(func() {}), None}, wantExc: mustCreateException(TypeErrorType, "unbound methods must have non-NULL im_class")},
+	}
+	for _, cas := range cases {
+		if err := runInvokeTestCase(MethodType.ToObject(), &cas); err != "" {
+			t.Error(err)
+		}
+	}
+}
+
 func TestMethodStrRepr(t *testing.T) {
-	foo := newBuiltinFunction("foo", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return None, nil })
+	foo := newBuiltinFunction("foo", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return None, nil }).ToObject()
 	cases := []invokeTestCase{
-		{args: wrapArgs(NewMethod(foo, None, StrType)), want: NewStr("<unbound method str.foo>").ToObject()},
-		{args: wrapArgs(NewMethod(foo, NewStr("wut").ToObject(), StrType)), want: NewStr("<bound method str.foo of 'wut'>").ToObject()},
-		{args: wrapArgs(NewMethod(foo, NewInt(123).ToObject(), TupleType)), want: NewStr("<bound method tuple.foo of 123>").ToObject()},
+		{args: wrapArgs(newTestMethod(foo, None, StrType.ToObject())), want: NewStr("<unbound method str.foo>").ToObject()},
+		{args: wrapArgs(newTestMethod(foo, NewStr("wut").ToObject(), StrType.ToObject())), want: NewStr("<bound method str.foo of 'wut'>").ToObject()},
+		{args: wrapArgs(newTestMethod(foo, NewInt(123).ToObject(), TupleType.ToObject())), want: NewStr("<bound method tuple.foo of 123>").ToObject()},
+		{args: wrapArgs(newTestMethod(foo, None, None)), want: NewStr("<unbound method ?.foo>").ToObject()},
 	}
 	for _, cas := range cases {
 		if err := runInvokeTestCase(wrapFuncForTest(ToStr), &cas); err != "" {
@@ -58,3 +106,7 @@ func TestMethodStrRepr(t *testing.T) {
 		}
 	}
 }
+
+func newTestMethod(function, self, class *Object) *Method {
+	return toMethodUnsafe(mustNotRaise(MethodType.Call(NewRootFrame(), Args{function, self, class}, nil)))
+}
diff --git a/runtime/module.go b/runtime/module.go
index 49aa0520..cfa2fb04 100644
--- a/runtime/module.go
+++ b/runtime/module.go
@@ -32,7 +32,8 @@ const (
 )
 
 var (
-	importMutex sync.Mutex
+	importMutex    sync.Mutex
+	moduleRegistry = map[string]*Code{}
 	// ModuleType is the object
representing the Python 'module' type. ModuleType = newBasisType("module", reflect.TypeOf(Module{}), toModuleUnsafe, ObjectType) // SysModules is the global dict of imported modules, aka sys.modules. @@ -44,12 +45,29 @@ type Module struct { Object mutex recursiveMutex state moduleState + code *Code } // ModuleInit functions are called when importing Grumpy modules to execute the // top level code for that module. type ModuleInit func(f *Frame, m *Module) *BaseException +// RegisterModule adds the named module to the registry so that it can be +// subsequently imported. +func RegisterModule(name string, c *Code) { + err := "" + importMutex.Lock() + if moduleRegistry[name] == nil { + moduleRegistry[name] = c + } else { + err = "module already registered: " + name + } + importMutex.Unlock() + if err != "" { + logFatal(err) + } +} + // ImportModule takes a fully qualified module name (e.g. a.b.c) and a slice of // code objects where the name of the i'th module is the prefix of name // ending in the i'th dot. The number of dot delimited parts of name must be the @@ -68,70 +86,24 @@ type ModuleInit func(f *Frame, m *Module) *BaseException // module, both invocations will produce the same module object and the module // is guaranteed to only be initialized once. The second invocation will not // return the module until it is fully initialized. -func ImportModule(f *Frame, name string, codeObjs []*Code) ([]*Object, *BaseException) { +func ImportModule(f *Frame, name string) ([]*Object, *BaseException) { + if strings.Contains(name, "/") { + o, raised := importOne(f, name) + if raised != nil { + return nil, raised + } + return []*Object{o}, nil + } parts := strings.Split(name, ".") numParts := len(parts) - if numParts != len(codeObjs) { - return nil, f.RaiseType(SystemErrorType, fmt.Sprintf("invalid import: %s", name)) - } result := make([]*Object, numParts) var prev *Object for i := 0; i < numParts; i++ { name := strings.Join(parts[:i+1], ".") - // We do very limited locking here resulting in some - // sys.modules consistency gotchas. - importMutex.Lock() - o, raised := SysModules.GetItemString(f, name) - if raised == nil && o == nil { - o = newModule(name, codeObjs[i].filename).ToObject() - raised = SysModules.SetItemString(f, name, o) - } - importMutex.Unlock() + o, raised := importOne(f, name) if raised != nil { return nil, raised } - if o.isInstance(ModuleType) { - var raised *BaseException - m := toModuleUnsafe(o) - m.mutex.Lock(f) - if m.state == moduleStateNew { - m.state = moduleStateInitializing - if _, raised = codeObjs[i].Eval(f, m.Dict(), nil, nil); raised == nil { - m.state = moduleStateReady - } else { - // If the module failed to initialize - // then before we relinquish the module - // lock, remove it from sys.modules. - // Threads waiting on this module will - // fail when they don't find it in - // sys.modules below. - e, tb := f.ExcInfo() - if _, raised := SysModules.DelItemString(f, name); raised != nil { - f.RestoreExc(e, tb) - } - } - } - m.mutex.Unlock(f) - if raised != nil { - return nil, raised - } - // The result should be what's in sys.modules, not - // necessarily the originally created module since this - // is CPython's behavior. - o, raised = SysModules.GetItemString(f, name) - if raised != nil { - return nil, raised - } - if o == nil { - // This can happen in the pathological case - // where the module clears itself from - // sys.modules during execution and is handled - // by CPython in PyImport_ExecCodeModuleEx in - // import.c. 
- format := "Loaded module %s not found in sys.modules" - return nil, f.RaiseType(ImportErrorType, fmt.Sprintf(format, name)) - } - } if prev != nil { if raised := SetAttr(f, prev, NewStr(parts[i]), o); raised != nil { return nil, raised @@ -143,40 +115,67 @@ func ImportModule(f *Frame, name string, codeObjs []*Code) ([]*Object, *BaseExce return result, nil } -// ImportNativeModule takes a fully qualified native module name (e.g. -// grumpy.native.fmt) and a mapping of module members that will be used to -// populate the module. The same logic is used as ImportModule for looking in -// sys.modules first. The last module created in this way is populated with the -// given members and returned. -func ImportNativeModule(f *Frame, name string, members map[string]*Object) (*Object, *BaseException) { - parts := strings.Split(name, ".") - numParts := len(parts) - var prev *Object - for i := 0; i < numParts; i++ { - name := strings.Join(parts[:i+1], ".") - importMutex.Lock() - o, raised := SysModules.GetItemString(f, name) - if raised == nil && o == nil { - o = newModule(name, "").ToObject() +func importOne(f *Frame, name string) (*Object, *BaseException) { + var c *Code + // We do very limited locking here resulting in some + // sys.modules consistency gotchas. + importMutex.Lock() + o, raised := SysModules.GetItemString(f, name) + if raised == nil && o == nil { + if c = moduleRegistry[name]; c == nil { + raised = f.RaiseType(ImportErrorType, name) + } else { + o = newModule(name, c.filename).ToObject() raised = SysModules.SetItemString(f, name, o) } - importMutex.Unlock() + } + importMutex.Unlock() + if raised != nil { + return nil, raised + } + if o.isInstance(ModuleType) { + var raised *BaseException + m := toModuleUnsafe(o) + m.mutex.Lock(f) + if m.state == moduleStateNew { + m.state = moduleStateInitializing + if _, raised = c.Eval(f, m.Dict(), nil, nil); raised == nil { + m.state = moduleStateReady + } else { + // If the module failed to initialize + // then before we relinquish the module + // lock, remove it from sys.modules. + // Threads waiting on this module will + // fail when they don't find it in + // sys.modules below. + e, tb := f.ExcInfo() + if _, raised := SysModules.DelItemString(f, name); raised != nil { + f.RestoreExc(e, tb) + } + } + } + m.mutex.Unlock(f) if raised != nil { return nil, raised } - if prev != nil { - if raised := SetAttr(f, prev, NewStr(parts[i]), o); raised != nil { - return nil, raised - } - } - prev = o - } - for k, v := range members { - if raised := SetAttr(f, prev, NewStr(k), v); raised != nil { + // The result should be what's in sys.modules, not + // necessarily the originally created module since this + // is CPython's behavior. + o, raised = SysModules.GetItemString(f, name) + if raised != nil { return nil, raised } + if o == nil { + // This can happen in the pathological case + // where the module clears itself from + // sys.modules during execution and is handled + // by CPython in PyImport_ExecCodeModuleEx in + // import.c. + format := "Loaded module %s not found in sys.modules" + return nil, f.RaiseType(ImportErrorType, fmt.Sprintf(format, name)) + } } - return prev, nil + return o, nil } // newModule creates a new Module object with the given fully qualified name @@ -209,7 +208,7 @@ func (m *Module) GetFilename(f *Frame) (*Str, *BaseException) { // GetName returns the __name__ attribute of m, raising SystemError if it does // not exist. 
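The registry above decouples compilation from import: generated module code is registered once by name, and ImportModule resolves it lazily, creating and caching the module in sys.modules. A rough sketch of the flow; fooCode stands in for the *Code object the compiler emits for module "foo":

    // Register at startup (RegisterModule logFatals on a duplicate name) ...
    func init() {
        RegisterModule("foo", fooCode)
    }

    // ... then import by dotted name; one module object per dotted part is returned.
    func useFoo(f *Frame) (*Object, *BaseException) {
        mods, raised := ImportModule(f, "foo")
        if raised != nil {
            return nil, raised
        }
        return mods[len(mods)-1], nil // the innermost ("foo") module
    }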
func (m *Module) GetName(f *Frame) (*Str, *BaseException) { - nameAttr, raised := GetAttr(f, m.ToObject(), NewStr("__name__"), None) + nameAttr, raised := GetAttr(f, m.ToObject(), internedName, None) if raised != nil { return nil, raised } @@ -233,7 +232,7 @@ func moduleInit(f *Frame, o *Object, args Args, _ KWArgs) (*Object, *BaseExcepti if raised := checkFunctionArgs(f, "__init__", args, expectedTypes...); raised != nil { return nil, raised } - if raised := SetAttr(f, o, NewStr("__name__"), args[0]); raised != nil { + if raised := SetAttr(f, o, internedName, args[0]); raised != nil { return nil, raised } if argc > 1 { @@ -288,19 +287,17 @@ func RunMain(code *Code) int { m := newModule("__main__", code.filename) m.state = moduleStateInitializing f := NewRootFrame() + f.code = code + f.globals = m.Dict() if raised := SysModules.SetItemString(f, "__main__", m.ToObject()); raised != nil { - fmt.Fprint(os.Stderr, raised.String()) + Stderr.writeString(raised.String()) } - _, e := code.Eval(f, m.Dict(), nil, nil) + _, e := code.fn(f, nil) if e == nil { return 0 } if !e.isInstance(SystemExitType) { - s, raised := FormatException(f, e) - if raised != nil { - s = e.String() - } - fmt.Fprint(os.Stderr, s) + Stderr.writeString(FormatExc(f)) return 1 } f.RestoreExc(nil, nil) @@ -315,7 +312,7 @@ func RunMain(code *Code) int { return 0 } if s, raised := ToStr(f, o); raised == nil { - fmt.Fprintln(os.Stderr, s.Value()) + Stderr.writeString(s.Value() + "\n") } return 1 } diff --git a/runtime/module_test.go b/runtime/module_test.go index f9f41206..2edc4bbe 100644 --- a/runtime/module_test.go +++ b/runtime/module_test.go @@ -40,7 +40,7 @@ func TestImportModule(t *testing.T) { return nil, f.RaiseType(AssertionErrorType, "circular imported recursively") } circularImported = true - if _, raised := ImportModule(f, "circular", []*Code{fooCode}); raised != nil { + if _, raised := ImportModule(f, "circular"); raised != nil { return nil, raised } return None, nil @@ -55,41 +55,47 @@ func TestImportModule(t *testing.T) { // NOTE: This test progressively evolves sys.modules, checking after // each test case that it's populated appropriately. 
oldSysModules := SysModules + oldModuleRegistry := moduleRegistry defer func() { SysModules = oldSysModules + moduleRegistry = oldModuleRegistry }() SysModules = newStringDict(map[string]*Object{"invalid": invalidModule}) + moduleRegistry = map[string]*Code{ + "foo": fooCode, + "foo.bar": barCode, + "foo.bar.baz": bazCode, + "foo.qux": quxCode, + "raises": raisesCode, + "circular": circularCode, + "clear": clearCode, + } cases := []struct { name string - codeObjs []*Code want *Object wantExc *BaseException wantSysModules *Dict }{ { - "foo.bar", - []*Code{}, + "noexist", nil, - mustCreateException(SystemErrorType, "invalid import: foo.bar"), + mustCreateException(ImportErrorType, "noexist"), newStringDict(map[string]*Object{"invalid": invalidModule}), }, { "invalid", - []*Code{fooCode}, NewTuple(invalidModule).ToObject(), nil, newStringDict(map[string]*Object{"invalid": invalidModule}), }, { "raises", - []*Code{raisesCode}, nil, mustCreateException(ValueErrorType, "uh oh"), newStringDict(map[string]*Object{"invalid": invalidModule}), }, { "foo", - []*Code{fooCode}, NewTuple(foo.ToObject()).ToObject(), nil, newStringDict(map[string]*Object{ @@ -99,7 +105,6 @@ func TestImportModule(t *testing.T) { }, { "foo", - []*Code{fooCode}, NewTuple(foo.ToObject()).ToObject(), nil, newStringDict(map[string]*Object{ @@ -109,7 +114,6 @@ func TestImportModule(t *testing.T) { }, { "foo.qux", - []*Code{fooCode, quxCode}, NewTuple(foo.ToObject(), qux.ToObject()).ToObject(), nil, newStringDict(map[string]*Object{ @@ -120,7 +124,6 @@ func TestImportModule(t *testing.T) { }, { "foo.bar.baz", - []*Code{fooCode, barCode, bazCode}, NewTuple(foo.ToObject(), bar.ToObject(), baz.ToObject()).ToObject(), nil, newStringDict(map[string]*Object{ @@ -133,7 +136,6 @@ func TestImportModule(t *testing.T) { }, { "circular", - []*Code{circularCode}, NewTuple(circularTestModule).ToObject(), nil, newStringDict(map[string]*Object{ @@ -147,7 +149,6 @@ func TestImportModule(t *testing.T) { }, { "clear", - []*Code{clearCode}, nil, mustCreateException(ImportErrorType, "Loaded module clear not found in sys.modules"), newStringDict(map[string]*Object{ @@ -161,7 +162,7 @@ func TestImportModule(t *testing.T) { }, } for _, cas := range cases { - mods, raised := ImportModule(f, cas.name, cas.codeObjs) + mods, raised := ImportModule(f, cas.name) var got *Object if raised == nil { got = NewTuple(mods...).ToObject() @@ -184,26 +185,6 @@ func TestImportModule(t *testing.T) { } } -func TestImportNativeModule(t *testing.T) { - f := NewRootFrame() - oldSysModules := SysModules - defer func() { - SysModules = oldSysModules - }() - SysModules = NewDict() - bar := newObject(ObjectType) - o := mustNotRaise(ImportNativeModule(f, "grumpy.native.foo", map[string]*Object{"Bar": bar})) - if !o.isInstance(ModuleType) { - t.Errorf(`ImportNativeModule("grumpy.native.foo") returned %v, want module`, o) - } else if nameAttr := mustNotRaise(GetAttr(f, o, NewStr("__name__"), None)); !nameAttr.isInstance(StrType) { - t.Errorf(`ImportNativeModule("grumpy.native.foo") returned module with non-string name %v`, nameAttr) - } else if gotName := toStrUnsafe(nameAttr).Value(); gotName != "grumpy.native.foo" { - t.Errorf(`ImportNativeModule("grumpy.native.foo") returned module named %q, want "grumpy.native.foo"`, gotName) - } else if gotBar := mustNotRaise(GetAttr(f, o, NewStr("Bar"), None)); gotBar != bar { - t.Errorf("foo.Bar = %v, want %v", gotBar, bar) - } -} - func TestModuleGetNameAndFilename(t *testing.T) { fun := wrapFuncForTest(func(f *Frame, m *Module) (*Tuple, 
*BaseException) { name, raised := m.GetName(f) @@ -234,7 +215,7 @@ func TestModuleInit(t *testing.T) { if raised != nil { return nil, raised } - name, raised := GetAttr(f, o, NewStr("__name__"), None) + name, raised := GetAttr(f, o, internedName, None) if raised != nil { return nil, raised } @@ -307,15 +288,15 @@ func TestRunMain(t *testing.T) { } func runMainAndCaptureStderr(code *Code) (int, string, error) { - oldStderr := os.Stderr + oldStderr := Stderr defer func() { - os.Stderr = oldStderr + Stderr = oldStderr }() r, w, err := os.Pipe() if err != nil { return 0, "", err } - os.Stderr = w + Stderr = NewFileFromFD(w.Fd(), nil) c := make(chan int) go func() { defer w.Close() diff --git a/runtime/native.go b/runtime/native.go index ce5bbe17..7e66edea 100644 --- a/runtime/native.go +++ b/runtime/native.go @@ -34,6 +34,8 @@ var ( // these kinds of values resolve directly to primitive Python types. nativeTypes = map[reflect.Type]*Type{ reflect.TypeOf(bool(false)): BoolType, + reflect.TypeOf(complex64(0)): ComplexType, + reflect.TypeOf(complex128(0)): ComplexType, reflect.TypeOf(float32(0)): FloatType, reflect.TypeOf(float64(0)): FloatType, reflect.TypeOf(int(0)): IntType, @@ -202,7 +204,7 @@ func nativeFuncGetName(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) } func nativeFuncRepr(f *Frame, o *Object) (*Object, *BaseException) { - name, raised := GetAttr(f, o, NewStr("__name__"), NewStr("").ToObject()) + name, raised := GetAttr(f, o, internedName, NewStr("").ToObject()) if raised != nil { return nil, raised } @@ -220,12 +222,133 @@ func initNativeFuncType(dict map[string]*Object) { nativeFuncType.slots.Repr = &unaryOpSlot{nativeFuncRepr} } +func nativeSliceGetItem(f *Frame, o, key *Object) (*Object, *BaseException) { + v := toNativeUnsafe(o).value + if key.typ.slots.Index != nil { + elem, raised := nativeSliceGetIndex(f, v, key) + if raised != nil { + return nil, raised + } + return WrapNative(f, elem) + } + if !key.isInstance(SliceType) { + return nil, f.RaiseType(TypeErrorType, fmt.Sprintf("native slice indices must be integers, not %s", key.typ.Name())) + } + s := toSliceUnsafe(key) + start, stop, step, sliceLen, raised := s.calcSlice(f, v.Len()) + if raised != nil { + return nil, raised + } + if step == 1 { + return WrapNative(f, v.Slice(start, stop)) + } + result := reflect.MakeSlice(v.Type(), sliceLen, sliceLen) + i := 0 + for j := start; j != stop; j += step { + resultElem := result.Index(i) + resultElem.Set(v.Index(j)) + i++ + } + return WrapNative(f, result) +} + func nativeSliceIter(f *Frame, o *Object) (*Object, *BaseException) { return newSliceIterator(toNativeUnsafe(o).value), nil } +func nativeSliceLen(f *Frame, o *Object) (*Object, *BaseException) { + return NewInt(toNativeUnsafe(o).value.Len()).ToObject(), nil +} + +func nativeSliceRepr(f *Frame, o *Object) (*Object, *BaseException) { + v := toNativeUnsafe(o).value + typeName := nativeTypeName(v.Type()) + if f.reprEnter(o) { + return NewStr(fmt.Sprintf("%s{...}", typeName)).ToObject(), nil + } + defer f.reprLeave(o) + numElems := v.Len() + elems := make([]*Object, numElems) + for i := 0; i < numElems; i++ { + elem, raised := WrapNative(f, v.Index(i)) + if raised != nil { + return nil, raised + } + elems[i] = elem + } + repr, raised := seqRepr(f, elems) + if raised != nil { + return nil, raised + } + return NewStr(fmt.Sprintf("%s{%s}", typeName, repr)).ToObject(), nil +} + +func nativeSliceSetItem(f *Frame, o, key, value *Object) *BaseException { + v := toNativeUnsafe(o).value + elemType := v.Type().Elem() + if 
key.typ.slots.Int != nil { + elem, raised := nativeSliceGetIndex(f, v, key) + if raised != nil { + return raised + } + if !elem.CanSet() { + return f.RaiseType(TypeErrorType, "cannot set slice element") + } + elemVal, raised := maybeConvertValue(f, value, elemType) + if raised != nil { + return raised + } + elem.Set(elemVal) + return nil + } + if key.isInstance(SliceType) { + s := toSliceUnsafe(key) + start, stop, step, sliceLen, raised := s.calcSlice(f, v.Len()) + if raised != nil { + return raised + } + if !v.Index(start).CanSet() { + return f.RaiseType(TypeErrorType, "cannot set slice element") + } + return seqApply(f, value, func(elems []*Object, _ bool) *BaseException { + numElems := len(elems) + if sliceLen != numElems { + format := "attempt to assign sequence of size %d to slice of size %d" + return f.RaiseType(ValueErrorType, fmt.Sprintf(format, numElems, sliceLen)) + } + i := 0 + for j := start; j != stop; j += step { + elemVal, raised := maybeConvertValue(f, elems[i], elemType) + if raised != nil { + return raised + } + v.Index(j).Set(elemVal) + i++ + } + return nil + }) + } + return f.RaiseType(TypeErrorType, fmt.Sprintf("native slice indices must be integers, not %s", key.Type().Name())) +} + func initNativeSliceType(map[string]*Object) { + nativeSliceType.slots.GetItem = &binaryOpSlot{nativeSliceGetItem} nativeSliceType.slots.Iter = &unaryOpSlot{nativeSliceIter} + nativeSliceType.slots.Len = &unaryOpSlot{nativeSliceLen} + nativeSliceType.slots.Repr = &unaryOpSlot{nativeSliceRepr} + nativeSliceType.slots.SetItem = &setItemSlot{nativeSliceSetItem} +} + +func nativeSliceGetIndex(f *Frame, slice reflect.Value, key *Object) (reflect.Value, *BaseException) { + i, raised := IndexInt(f, key) + if raised != nil { + return reflect.Value{}, raised + } + i, raised = seqCheckedIndex(f, slice.Len(), i) + if raised != nil { + return reflect.Value{}, raised + } + return slice.Index(i), nil } type sliceIterator struct { @@ -312,7 +435,10 @@ func WrapNative(f *Frame, v reflect.Value) (*Object, *BaseException) { // TODO: Make native bool subtypes singletons and add support // for __new__ so we can use t.Call() here. return (&Int{Object{typ: t}, i}).ToObject(), nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16: + case reflect.Complex64: + case reflect.Complex128: + return t.Call(f, Args{NewComplex(v.Complex()).ToObject()}, nil) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: return t.Call(f, Args{NewInt(int(v.Int())).ToObject()}, nil) // Handle potentially large ints separately in case of overflow. case reflect.Int64: @@ -321,7 +447,7 @@ func WrapNative(f *Frame, v reflect.Value) (*Object, *BaseException) { return NewLong(big.NewInt(i)).ToObject(), nil } return t.Call(f, Args{NewInt(int(i)).ToObject()}, nil) - case reflect.Uint, reflect.Uint32, reflect.Uint64: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: i := v.Uint() if i > uint64(MaxInt) { return t.Call(f, Args{NewLong((new(big.Int).SetUint64(i))).ToObject()}, nil) @@ -391,13 +517,15 @@ func getNativeType(rtype reflect.Type) *Type { // object. 
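With GetItem, SetItem, Len and Repr wired up, a wrapped Go slice behaves like a mutable Python sequence. A small sketch of driving it from inside the runtime package, mirroring the native slice tests below (newTestSlice is the test helper used there):

    // Wrap a Go slice and manipulate it through the sequence protocols.
    f := NewRootFrame()
    s := []int{1, 2, 3, 4, 5}
    o := mustNotRaise(WrapNative(f, reflect.ValueOf(s)))
    first := mustNotRaise(GetItem(f, o, NewInt(0).ToObject()))        // 1
    sub := mustNotRaise(GetItem(f, o, newTestSlice(1, 4).ToObject())) // wraps []int{2, 3, 4}
    if raised := SetItem(f, o, NewInt(4).ToObject(), NewInt(50).ToObject()); raised != nil {
        // value conversion failures surface as "an int is required" etc.
    }
    _, _ = first, sub // s now shares its backing array with the wrapper: []int{1, 2, 3, 4, 50}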
base := nativeType switch rtype.Kind() { + case reflect.Complex64, reflect.Complex128: + base = ComplexType case reflect.Float32, reflect.Float64: base = FloatType case reflect.Func: base = nativeFuncType case reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int8, reflect.Int, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint8, reflect.Uint, reflect.Uintptr: base = IntType - case reflect.Slice: + case reflect.Array, reflect.Slice: base = nativeSliceType case reflect.String: base = StrType @@ -427,7 +555,7 @@ func getNativeType(rtype reflect.Type) *Type { d[name] = newNativeField(name, i, t) } } - t.dict = newStringDict(d) + t.setDict(newStringDict(d)) // This cannot fail since we're defining simple classes. if err := prepareType(t); err != "" { logFatal(err) @@ -439,7 +567,7 @@ func getNativeType(rtype reflect.Type) *Type { } func newNativeField(name string, i int, t *Type) *Object { - nativeFieldGet := func(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + get := newBuiltinFunction(name, func(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkFunctionArgs(f, name, args, t); raised != nil { return nil, raised } @@ -448,9 +576,28 @@ func newNativeField(name string, i int, t *Type) *Object { v = v.Elem() } return WrapNative(f, v.Field(i)) - } - get := newBuiltinFunction(name, nativeFieldGet).ToObject() - return newProperty(get, nil, nil).ToObject() + }).ToObject() + set := newBuiltinFunction(name, func(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + if raised := checkFunctionArgs(f, name, args, t, ObjectType); raised != nil { + return nil, raised + } + v := toNativeUnsafe(args[0]).value + for v.Type().Kind() == reflect.Ptr { + v = v.Elem() + } + field := v.Field(i) + if !field.CanSet() { + msg := fmt.Sprintf("cannot set field '%s' of type '%s'", name, t.Name()) + return nil, f.RaiseType(TypeErrorType, msg) + } + v, raised := maybeConvertValue(f, args[1], field.Type()) + if raised != nil { + return nil, raised + } + field.Set(v) + return None, nil + }).ToObject() + return newProperty(get, set, nil).ToObject() } func newNativeMethod(name string, fun reflect.Value) *Object { @@ -472,7 +619,7 @@ func maybeConvertValue(f *Frame, o *Object, expectedRType reflect.Type) (reflect case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer: return reflect.Zero(expectedRType), nil default: - return reflect.Value{}, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert None to %s", expectedRType)) + return reflect.Value{}, f.RaiseType(TypeErrorType, fmt.Sprintf("an %s is required", expectedRType)) } } val, raised := ToNative(f, o) @@ -494,7 +641,7 @@ func maybeConvertValue(f *Frame, o *Object, expectedRType reflect.Type) (reflect } break } - return reflect.Value{}, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType)) + return reflect.Value{}, f.RaiseType(TypeErrorType, fmt.Sprintf("an %s is required", expectedRType)) } func nativeFuncTypeName(rtype reflect.Type) string { @@ -558,10 +705,12 @@ func nativeInvoke(f *Frame, fun reflect.Value, args Args) (ret *Object, raised * } } } + origExc, origTb := f.RestoreExc(nil, nil) result := fun.Call(nativeArgs) if e, _ := f.ExcInfo(); e != nil { return nil, e } + f.RestoreExc(origExc, origTb) numResults := len(result) if numResults > 0 && result[numResults-1].Type() == reflect.TypeOf((*BaseException)(nil)) { numResults-- diff --git a/runtime/native_test.go b/runtime/native_test.go index 3c69e636..6419f785 
100644 --- a/runtime/native_test.go +++ b/runtime/native_test.go @@ -16,6 +16,7 @@ package grumpy import ( "errors" + "fmt" "math/big" "reflect" "regexp" @@ -80,7 +81,7 @@ func TestNativeFuncCall(t *testing.T) { func TestNativeFuncName(t *testing.T) { re := regexp.MustCompile(`(\w+\.)*\w+$`) fun := wrapFuncForTest(func(f *Frame, o *Object) (string, *BaseException) { - desc, raised := GetItem(f, nativeFuncType.Dict().ToObject(), NewStr("__name__").ToObject()) + desc, raised := GetItem(f, nativeFuncType.Dict().ToObject(), internedName.ToObject()) if raised != nil { return "", raised } @@ -397,7 +398,7 @@ func TestMaybeConvertValue(t *testing.T) { {NewFloat(0.5).ToObject(), reflect.TypeOf(float32(0)), float32(0.5), nil}, {fooNative.ToObject(), reflect.TypeOf(&fooStruct{}), foo, nil}, {None, reflect.TypeOf((*int)(nil)), (*int)(nil), nil}, - {None, reflect.TypeOf(""), nil, mustCreateException(TypeErrorType, "cannot convert None to string")}, + {None, reflect.TypeOf(""), nil, mustCreateException(TypeErrorType, "an string is required")}, } for _, cas := range cases { fun := wrapFuncForTest(func(f *Frame) *BaseException { @@ -479,11 +480,11 @@ func TestNewNativeFieldChecksInstanceType(t *testing.T) { } // When its field property is assigned to a different type - property, raised := native.typ.dict.GetItemString(f, "foo") + property, raised := native.typ.Dict().GetItemString(f, "foo") if raised != nil { t.Fatal("Unexpected exception:", raised) } - if raised := IntType.dict.SetItemString(f, "foo", property); raised != nil { + if raised := IntType.Dict().SetItemString(f, "foo", property); raised != nil { t.Fatal("Unexpected exception:", raised) } @@ -496,6 +497,182 @@ func TestNewNativeFieldChecksInstanceType(t *testing.T) { } } +func TestNativeSliceGetItem(t *testing.T) { + testRange := make([]int, 20) + for i := 0; i < len(testRange); i++ { + testRange[i] = i + } + badIndexType := newTestClass("badIndex", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__index__": newBuiltinFunction("__index__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return nil, f.RaiseType(ValueErrorType, "wut") + }).ToObject(), + })) + cases := []invokeTestCase{ + {args: wrapArgs(testRange, 0), want: NewInt(0).ToObject()}, + {args: wrapArgs(testRange, 19), want: NewInt(19).ToObject()}, + {args: wrapArgs([]struct{}{}, 101), wantExc: mustCreateException(IndexErrorType, "index out of range")}, + {args: wrapArgs([]bool{true}, None), wantExc: mustCreateException(TypeErrorType, "native slice indices must be integers, not NoneType")}, + {args: wrapArgs(testRange, newObject(badIndexType)), wantExc: mustCreateException(ValueErrorType, "wut")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(GetItem), &cas); err != "" { + t.Error(err) + } + } +} + +func TestNativeSliceGetItemSlice(t *testing.T) { + fun := wrapFuncForTest(func(f *Frame, o *Object, slice *Slice, want interface{}) *BaseException { + item, raised := GetItem(f, o, slice.ToObject()) + if raised != nil { + return raised + } + val, raised := ToNative(f, item) + if raised != nil { + return raised + } + v := val.Interface() + msg := fmt.Sprintf("%v[%v] = %v, want %v", o, slice, v, want) + return Assert(f, GetBool(reflect.DeepEqual(v, want)).ToObject(), NewStr(msg).ToObject()) + }) + type fooStruct struct { + Bar int + } + cases := []invokeTestCase{ + {args: wrapArgs([]string{}, newTestSlice(50, 100), []string{}), want: None}, + {args: wrapArgs([]int{1, 2, 3, 4, 5}, newTestSlice(1, 3, None), []int{2, 3}), want: None}, 
+ {args: wrapArgs([]fooStruct{fooStruct{1}, fooStruct{10}}, newTestSlice(-1, None, None), []fooStruct{fooStruct{10}}), want: None}, + {args: wrapArgs([]int{1, 2, 3, 4, 5}, newTestSlice(1, None, 2), []int{2, 4}), want: None}, + {args: wrapArgs([]float64{1.0, 2.0, 3.0, 4.0, 5.0}, newTestSlice(big.NewInt(1), None, 2), []float64{2.0, 4.0}), want: None}, + {args: wrapArgs([]string{"1", "2", "3", "4", "5"}, newTestSlice(1, big.NewInt(5), 2), []string{"2", "4"}), want: None}, + {args: wrapArgs([]int{1, 2, 3, 4, 5}, newTestSlice(1, None, big.NewInt(2)), []int{2, 4}), want: None}, + {args: wrapArgs([]int16{1, 2, 3, 4, 5}, newTestSlice(1.0, 3, None), None), wantExc: mustCreateException(TypeErrorType, errBadSliceIndex)}, + {args: wrapArgs([]byte{1, 2, 3}, newTestSlice(1, None, 0), None), wantExc: mustCreateException(ValueErrorType, "slice step cannot be zero")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + +func TestNativeSliceLen(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs([]string{"foo", "bar"}), want: NewInt(2).ToObject()}, + {args: wrapArgs(make([]int, 100)), want: NewInt(100).ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(Len), &cas); err != "" { + t.Error(err) + } + } +} + +func TestNativeSliceStrRepr(t *testing.T) { + slice := make([]*Object, 2) + o := mustNotRaise(WrapNative(NewRootFrame(), reflect.ValueOf(slice))) + slice[0] = o + slice[1] = NewStr("foo").ToObject() + cases := []invokeTestCase{ + {args: wrapArgs([]string{"foo", "bar"}), want: NewStr("[]string{'foo', 'bar'}").ToObject()}, + {args: wrapArgs([]uint16{123}), want: NewStr("[]uint16{123}").ToObject()}, + {args: wrapArgs(o), want: NewStr("[]*Object{[]*Object{...}, 'foo'}").ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(wrapFuncForTest(ToStr), &cas); err != "" { + t.Error(err) + } + if err := runInvokeTestCase(wrapFuncForTest(Repr), &cas); err != "" { + t.Error(err) + } + } +} + +func TestNativeSliceSetItemSlice(t *testing.T) { + fun := wrapFuncForTest(func(f *Frame, o, index, value *Object, want interface{}) *BaseException { + originalStr := o.String() + if raised := SetItem(f, o, index, value); raised != nil { + return raised + } + val, raised := ToNative(f, o) + if raised != nil { + return raised + } + v := val.Interface() + msg := fmt.Sprintf("%v[%v] = %v -> %v, want %v", originalStr, index, value, o, want) + return Assert(f, GetBool(reflect.DeepEqual(v, want)).ToObject(), NewStr(msg).ToObject()) + }) + type fooStruct struct { + bar []int + } + foo := fooStruct{[]int{1, 2, 3}} + bar := mustNotRaise(WrapNative(NewRootFrame(), reflect.ValueOf(foo).Field(0))) + cases := []invokeTestCase{ + {args: wrapArgs([]string{"foo", "bar"}, 1, "baz", []string{"foo", "baz"}), want: None}, + {args: wrapArgs([]uint16{1, 2, 3}, newTestSlice(1), newTestList(4), []uint16{4, 2, 3}), want: None}, + {args: wrapArgs([]int{1, 2, 4, 5}, newTestSlice(1, None, 2), newTestTuple(10, 20), []int{1, 10, 4, 20}), want: None}, + {args: wrapArgs([]float64{}, newTestSlice(4, 8, 0), NewList(), None), wantExc: mustCreateException(ValueErrorType, "slice step cannot be zero")}, + {args: wrapArgs([]string{"foo", "bar"}, -100, None, None), wantExc: mustCreateException(IndexErrorType, "index out of range")}, + {args: wrapArgs([]int{}, 101, None, None), wantExc: mustCreateException(IndexErrorType, "index out of range")}, + {args: wrapArgs([]bool{true}, None, false, None), wantExc: 
mustCreateException(TypeErrorType, "native slice indices must be integers, not NoneType")}, + {args: wrapArgs([]int8{1, 2, 3}, newTestSlice(0), []int8{0}, []int8{0, 1, 2, 3}), wantExc: mustCreateException(ValueErrorType, "attempt to assign sequence of size 1 to slice of size 0")}, + {args: wrapArgs([]int{1, 2, 3}, newTestSlice(2, None), newTestList("foo"), None), wantExc: mustCreateException(TypeErrorType, "an int is required")}, + {args: wrapArgs(bar, 1, 42, None), wantExc: mustCreateException(TypeErrorType, "cannot set slice element")}, + {args: wrapArgs(bar, newTestSlice(1), newTestList(42), None), wantExc: mustCreateException(TypeErrorType, "cannot set slice element")}, + {args: wrapArgs([]string{"foo", "bar"}, 1, 123.0, None), wantExc: mustCreateException(TypeErrorType, "an string is required")}, + {args: wrapArgs([]string{"foo", "bar"}, 1, 123.0, None), wantExc: mustCreateException(TypeErrorType, "an string is required")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + +func TestNativeStructFieldGet(t *testing.T) { + fun := wrapFuncForTest(func(f *Frame, o *Object, attr *Str) (*Object, *BaseException) { + return GetAttr(f, o, attr, nil) + }) + type fooStruct struct { + bar int + Baz float64 + } + cases := []invokeTestCase{ + {args: wrapArgs(fooStruct{bar: 1}, "bar"), want: NewInt(1).ToObject()}, + {args: wrapArgs(&fooStruct{Baz: 3.14}, "Baz"), want: NewFloat(3.14).ToObject()}, + {args: wrapArgs(fooStruct{}, "qux"), wantExc: mustCreateException(AttributeErrorType, `'fooStruct' object has no attribute 'qux'`)}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + +func TestNativeStructFieldSet(t *testing.T) { + fun := wrapFuncForTest(func(f *Frame, o *Object, attr *Str, value *Object) (*Object, *BaseException) { + if raised := SetAttr(f, o, attr, value); raised != nil { + return nil, raised + } + return GetAttr(f, o, attr, nil) + }) + type fooStruct struct { + bar int + Baz float64 + } + cases := []invokeTestCase{ + {args: wrapArgs(&fooStruct{}, "Baz", 1.5), want: NewFloat(1.5).ToObject()}, + {args: wrapArgs(fooStruct{}, "bar", 123), wantExc: mustCreateException(TypeErrorType, `cannot set field 'bar' of type 'fooStruct'`)}, + {args: wrapArgs(fooStruct{}, "qux", "abc"), wantExc: mustCreateException(AttributeErrorType, `'fooStruct' has no attribute 'qux'`)}, + {args: wrapArgs(&fooStruct{}, "Baz", "abc"), wantExc: mustCreateException(TypeErrorType, "an float64 is required")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + func wrapArgs(elems ...interface{}) Args { f := NewRootFrame() argc := len(elems) diff --git a/runtime/object.go b/runtime/object.go index 585c957f..0d343212 100644 --- a/runtime/object.go +++ b/runtime/object.go @@ -17,6 +17,7 @@ package grumpy import ( "fmt" "reflect" + "sync/atomic" "unsafe" ) @@ -39,7 +40,7 @@ var ( // Object represents Python 'object' objects. type Object struct { typ *Type `attr:"__class__"` - dict *Dict `attr:"__dict__"` + dict *Dict ref *WeakRef } @@ -50,7 +51,7 @@ func newObject(t *Type) *Object { } o := (*Object)(unsafe.Pointer(reflect.New(t.basis).Pointer())) o.typ = t - o.dict = dict + o.setDict(dict) return o } @@ -66,7 +67,13 @@ func (o *Object) Call(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseExcepti // Dict returns o's object dict, aka __dict__. 
func (o *Object) Dict() *Dict { - return o.dict + p := (*unsafe.Pointer)(unsafe.Pointer(&o.dict)) + return (*Dict)(atomic.LoadPointer(p)) +} + +func (o *Object) setDict(d *Dict) { + p := (*unsafe.Pointer)(unsafe.Pointer(&o.dict)) + atomic.StorePointer(p, unsafe.Pointer(d)) } // String returns a string representation of o, e.g. for debugging. @@ -109,8 +116,9 @@ func objectDelAttr(f *Frame, o *Object, name *Str) *BaseException { } } deleted := false - if o.dict != nil { - deleted, raised = o.dict.DelItem(f, name.ToObject()) + d := o.Dict() + if d != nil { + deleted, raised = d.DelItem(f, name.ToObject()) if raised != nil { return raised } @@ -138,7 +146,7 @@ func objectGetAttribute(f *Frame, o *Object, name *Str) (*Object, *BaseException } } // Look in the object's dict. - if d := o.dict; d != nil { + if d := o.Dict(); d != nil { value, raised := d.GetItem(f, name.ToObject()) if value != nil || raised != nil { return value, raised @@ -208,8 +216,8 @@ func objectSetAttr(f *Frame, o *Object, name *Str, value *Object) *BaseException return typeSet.Fn(f, typeAttr, o, value) } } - if o.dict != nil { - if raised := o.dict.SetItem(f, name.ToObject(), value); raised == nil || !raised.isInstance(KeyErrorType) { + if d := o.Dict(); d != nil { + if raised := d.SetItem(f, name.ToObject(), value); raised == nil || !raised.isInstance(KeyErrorType) { return nil } } @@ -220,6 +228,7 @@ func initObjectType(dict map[string]*Object) { ObjectType.typ = TypeType dict["__reduce__"] = objectReduceFunc dict["__reduce_ex__"] = newBuiltinFunction("__reduce_ex__", objectReduceEx).ToObject() + dict["__dict__"] = newProperty(newBuiltinFunction("_get_dict", objectGetDict).ToObject(), newBuiltinFunction("_set_dict", objectSetDict).ToObject(), nil).ToObject() ObjectType.slots.DelAttr = &delAttrSlot{objectDelAttr} ObjectType.slots.GetAttribute = &getAttributeSlot{objectGetAttribute} ObjectType.slots.Hash = &unaryOpSlot{objectHash} @@ -283,11 +292,11 @@ func objectReduceCommon(f *Frame, args Args) (*Object, *BaseException) { return nil, raised } } - newArgs := NewTuple(t.ToObject(), basisType.ToObject(), state).ToObject() + newArgs := NewTuple3(t.ToObject(), basisType.ToObject(), state).ToObject() if d := o.Dict(); d != nil { - return NewTuple(objectReconstructorFunc, newArgs, d.ToObject()).ToObject(), nil + return NewTuple3(objectReconstructorFunc, newArgs, d.ToObject()).ToObject(), nil } - return NewTuple(objectReconstructorFunc, newArgs).ToObject(), nil + return NewTuple2(objectReconstructorFunc, newArgs).ToObject(), nil } newArgs := []*Object{t.ToObject()} getNewArgsMethod, raised := GetAttr(f, o, NewStr("__getnewargs__"), None) @@ -306,8 +315,8 @@ func objectReduceCommon(f *Frame, args Args) (*Object, *BaseException) { newArgs = append(newArgs, toTupleUnsafe(extraNewArgs).elems...) } dict := None - if o.dict != nil { - dict = o.dict.ToObject() + if d := o.Dict(); d != nil { + dict = d.ToObject() } // For proto >= 2 include list and dict items. 
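__dict__ is now exposed through a property backed by objectGetDict and objectSetDict (below), with the underlying pointer read and written atomically via Dict and setDict. A short sketch of the visible effect, using a hypothetical user-defined class fooType; plain object instances have no __dict__ and raise AttributeError instead:

    // Swap an instance's __dict__ wholesale, as TestObjectSetDict below does.
    f := NewRootFrame()
    foo := newObject(fooType) // fooType: some user-defined class, not ObjectType itself
    d := newStringDict(map[string]*Object{"bar": NewInt(123).ToObject()})
    if raised := SetAttr(f, foo, NewStr("__dict__"), d.ToObject()); raised != nil {
        // e.g. TypeError if the assigned value is not a dict
    }
    bar := mustNotRaise(GetAttr(f, foo, NewStr("bar"), nil)) // 123, served from the new dict
    _ = bar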
listItems := None @@ -332,5 +341,31 @@ func objectReduceCommon(f *Frame, args Args) (*Object, *BaseException) { if raised != nil { return nil, raised } - return NewTuple(newFunc, NewTuple(newArgs...).ToObject(), dict, listItems, dictItems).ToObject(), nil + return NewTuple5(newFunc, NewTuple(newArgs...).ToObject(), dict, listItems, dictItems).ToObject(), nil +} + +func objectGetDict(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "_get_dict", args, ObjectType); raised != nil { + return nil, raised + } + o := args[0] + d := o.Dict() + if d == nil { + format := "'%s' object has no attribute '__dict__'" + return nil, f.RaiseType(AttributeErrorType, fmt.Sprintf(format, o.typ.Name())) + } + return args[0].Dict().ToObject(), nil +} + +func objectSetDict(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "_set_dict", args, ObjectType, DictType); raised != nil { + return nil, raised + } + o := args[0] + if o.Type() == ObjectType { + format := "'%s' object has no attribute '__dict__'" + return nil, f.RaiseType(AttributeErrorType, fmt.Sprintf(format, o.typ.Name())) + } + o.setDict(toDictUnsafe(args[1])) + return None, nil } diff --git a/runtime/object_test.go b/runtime/object_test.go index 22b6e3cd..2065f6d6 100644 --- a/runtime/object_test.go +++ b/runtime/object_test.go @@ -105,7 +105,7 @@ func TestObjectDelAttr(t *testing.T) { }) dellerType := newTestClass("Deller", []*Type{ObjectType}, newStringDict(map[string]*Object{ "__get__": newBuiltinFunction("__get__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { - attr, raised := args[1].dict.GetItemString(f, "attr") + attr, raised := args[1].Dict().GetItemString(f, "attr") if raised != nil { return nil, raised } @@ -115,7 +115,7 @@ func TestObjectDelAttr(t *testing.T) { return attr, nil }).ToObject(), "__delete__": newBuiltinFunction("__delete__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { - deleted, raised := args[1].dict.DelItemString(f, "attr") + deleted, raised := args[1].Dict().DelItemString(f, "attr") if raised != nil { return nil, raised } @@ -127,7 +127,7 @@ func TestObjectDelAttr(t *testing.T) { })) fooType := newTestClass("Foo", []*Type{ObjectType}, newStringDict(map[string]*Object{"deller": newObject(dellerType)})) foo := newObject(fooType) - if raised := foo.dict.SetItemString(NewRootFrame(), "attr", NewInt(123).ToObject()); raised != nil { + if raised := foo.Dict().SetItemString(NewRootFrame(), "attr", NewInt(123).ToObject()); raised != nil { t.Fatal(raised) } cases := []invokeTestCase{ @@ -180,13 +180,13 @@ func TestObjectGetAttribute(t *testing.T) { "barsetter": setter, })) foo := newObject(fooType) - if raised := foo.dict.SetItemString(NewRootFrame(), "fooattr", True.ToObject()); raised != nil { + if raised := foo.Dict().SetItemString(NewRootFrame(), "fooattr", True.ToObject()); raised != nil { t.Fatal(raised) } - if raised := foo.dict.SetItemString(NewRootFrame(), "barattr", NewInt(-1).ToObject()); raised != nil { + if raised := foo.Dict().SetItemString(NewRootFrame(), "barattr", NewInt(-1).ToObject()); raised != nil { t.Fatal(raised) } - if raised := foo.dict.SetItemString(NewRootFrame(), "barsetter", NewStr("NOT setter").ToObject()); raised != nil { + if raised := foo.Dict().SetItemString(NewRootFrame(), "barsetter", NewStr("NOT setter").ToObject()); raised != nil { t.Fatal(raised) } cases := []invokeTestCase{ @@ -205,6 +205,52 @@ func TestObjectGetAttribute(t *testing.T) { } } +func 
TestObjectGetDict(t *testing.T) { + fooType := newTestClass("Foo", []*Type{ObjectType}, NewDict()) + foo := newObject(fooType) + if raised := SetAttr(NewRootFrame(), foo, NewStr("bar"), NewInt(123).ToObject()); raised != nil { + panic(raised) + } + fun := wrapFuncForTest(func(f *Frame, o *Object) (*Object, *BaseException) { + return GetAttr(f, o, NewStr("__dict__"), nil) + }) + cases := []invokeTestCase{ + {args: wrapArgs(newObject(ObjectType)), wantExc: mustCreateException(AttributeErrorType, "'object' object has no attribute '__dict__'")}, + {args: wrapArgs(newObject(fooType)), want: NewDict().ToObject()}, + {args: wrapArgs(foo), want: newStringDict(map[string]*Object{"bar": NewInt(123).ToObject()}).ToObject()}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + +func TestObjectSetDict(t *testing.T) { + fooType := newTestClass("Foo", []*Type{ObjectType}, NewDict()) + testDict := newStringDict(map[string]*Object{"bar": NewInt(123).ToObject()}) + fun := wrapFuncForTest(func(f *Frame, o, val *Object) (*Object, *BaseException) { + if raised := SetAttr(f, o, NewStr("__dict__"), val); raised != nil { + return nil, raised + } + d := o.Dict() + if d == nil { + return None, nil + } + return d.ToObject(), nil + }) + cases := []invokeTestCase{ + {args: wrapArgs(newObject(ObjectType), NewDict()), wantExc: mustCreateException(AttributeErrorType, "'object' object has no attribute '__dict__'")}, + {args: wrapArgs(newObject(fooType), testDict), want: testDict.ToObject()}, + {args: wrapArgs(newObject(fooType), 123), wantExc: mustCreateException(TypeErrorType, "'_set_dict' requires a 'dict' object but received a 'int'")}, + } + for _, cas := range cases { + if err := runInvokeTestCase(fun, &cas); err != "" { + t.Error(err) + } + } +} + func TestObjectNew(t *testing.T) { foo := makeTestType("Foo", ObjectType) foo.flags &= ^typeFlagInstantiable @@ -335,7 +381,7 @@ func TestObjectSetAttr(t *testing.T) { }) setterType := newTestClass("Setter", []*Type{ObjectType}, newStringDict(map[string]*Object{ "__get__": newBuiltinFunction("__get__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { - item, raised := args[1].dict.GetItemString(f, "attr") + item, raised := args[1].Dict().GetItemString(f, "attr") if raised != nil { return nil, raised } @@ -345,7 +391,7 @@ func TestObjectSetAttr(t *testing.T) { return item, nil }).ToObject(), "__set__": newBuiltinFunction("__set__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { - if raised := args[1].dict.SetItemString(f, "attr", NewTuple(args.makeCopy()...).ToObject()); raised != nil { + if raised := args[1].Dict().SetItemString(f, "attr", NewTuple(args.makeCopy()...).ToObject()); raised != nil { return nil, raised } return None, nil diff --git a/runtime/range.go b/runtime/range.go index 26733b12..bae8ab3f 100644 --- a/runtime/range.go +++ b/runtime/range.go @@ -103,7 +103,7 @@ func enumerateNext(f *Frame, o *Object) (ret *Object, raised *BaseException) { raised = f.Raise(StopIterationType.ToObject(), nil, nil) e.index = -1 } else { - ret = NewTuple(NewInt(e.index).ToObject(), item).ToObject() + ret = NewTuple2(NewInt(e.index).ToObject(), item).ToObject() e.index++ } } diff --git a/runtime/seq.go b/runtime/seq.go index 8ac8f08b..1509be28 100644 --- a/runtime/seq.go +++ b/runtime/seq.go @@ -173,6 +173,23 @@ func seqFindFirst(f *Frame, iterable *Object, pred func(*Object) (bool, *BaseExc return false, nil } +func seqFindElem(f *Frame, elems []*Object, o *Object) (int, 
*BaseException) { + for i, elem := range elems { + eq, raised := Eq(f, elem, o) + if raised != nil { + return -1, raised + } + found, raised := IsTrue(f, eq) + if raised != nil { + return -1, raised + } + if found { + return i, nil + } + } + return -1, nil +} + func seqForEach(f *Frame, iterable *Object, callback func(*Object) *BaseException) *BaseException { iter, raised := Iter(f, iterable) if raised != nil { diff --git a/runtime/set.go b/runtime/set.go index b515b056..54aa2914 100644 --- a/runtime/set.go +++ b/runtime/set.go @@ -81,7 +81,11 @@ func toSetUnsafe(o *Object) *Set { // Add inserts key into s. If key already exists then does nothing. func (s *Set) Add(f *Frame, key *Object) (bool, *BaseException) { - return s.dict.putItem(f, key, None) + origin, raised := s.dict.putItem(f, key, None, true) + if raised != nil { + return false, raised + } + return origin == nil, nil } // Contains returns true if key exists in s. @@ -102,10 +106,7 @@ func (s *Set) ToObject() *Object { // Update inserts all elements in the iterable o into s. func (s *Set) Update(f *Frame, o *Object) *BaseException { raised := seqForEach(f, o, func(key *Object) *BaseException { - if raised := s.dict.SetItem(f, key, None); raised != nil { - return raised - } - return nil + return s.dict.SetItem(f, key, None) }) return raised } diff --git a/runtime/set_test.go b/runtime/set_test.go index 73216899..6b68eb76 100644 --- a/runtime/set_test.go +++ b/runtime/set_test.go @@ -184,10 +184,10 @@ func TestSetIter(t *testing.T) { cases := []invokeTestCase{ {args: wrapArgs(NewSet()), want: NewTuple().ToObject()}, {args: wrapArgs(newTestSet(1, 2, 3)), want: newTestTuple(1, 2, 3).ToObject()}, - {args: wrapArgs(newTestSet("foo", 3.14)), want: newTestTuple(3.14, "foo").ToObject()}, + {args: wrapArgs(newTestSet("foo", 3.14)), want: newTestTuple("foo", 3.14).ToObject()}, {args: wrapArgs(newTestFrozenSet()), want: NewTuple().ToObject()}, {args: wrapArgs(newTestFrozenSet(1, 2, 3)), want: newTestTuple(1, 2, 3).ToObject()}, - {args: wrapArgs(newTestFrozenSet("foo", 3.14)), want: newTestTuple(3.14, "foo").ToObject()}, + {args: wrapArgs(newTestFrozenSet("foo", 3.14)), want: newTestTuple("foo", 3.14).ToObject()}, } for _, cas := range cases { if err := runInvokeTestCase(fun, &cas); err != "" { diff --git a/runtime/slice.go b/runtime/slice.go index be5af2d4..cdb8b293 100644 --- a/runtime/slice.go +++ b/runtime/slice.go @@ -129,14 +129,14 @@ func sliceNew(f *Frame, t *Type, args Args, _ KWArgs) (*Object, *BaseException) func sliceRepr(f *Frame, o *Object) (*Object, *BaseException) { s := toSliceUnsafe(o) - elems := []*Object{None, s.stop, None} + elem0, elem1, elem2 := None, s.stop, None if s.start != nil { - elems[0] = s.start + elem0 = s.start } if s.step != nil { - elems[2] = s.step + elem2 = s.step } - r, raised := Repr(f, NewTuple(elems...).ToObject()) + r, raised := Repr(f, NewTuple3(elem0, elem1, elem2).ToObject()) if raised != nil { return nil, raised } diff --git a/runtime/slots.go b/runtime/slots.go index f168f390..b2dd2426 100644 --- a/runtime/slots.go +++ b/runtime/slots.go @@ -375,22 +375,29 @@ type typeSlots struct { Basis *basisSlot Call *callSlot Cmp *binaryOpSlot + Complex *unaryOpSlot Contains *binaryOpSlot DelAttr *delAttrSlot Delete *deleteSlot DelItem *delItemSlot Div *binaryOpSlot + DivMod *binaryOpSlot Eq *binaryOpSlot Float *unaryOpSlot + FloorDiv *binaryOpSlot GE *binaryOpSlot Get *getSlot GetAttribute *getAttributeSlot GetItem *binaryOpSlot GT *binaryOpSlot Hash *unaryOpSlot + Hex *unaryOpSlot IAdd *binaryOpSlot IAnd 
*binaryOpSlot IDiv *binaryOpSlot + IDivMod *binaryOpSlot + IFloorDiv *binaryOpSlot + ILShift *binaryOpSlot IMod *binaryOpSlot IMul *binaryOpSlot Index *unaryOpSlot @@ -399,6 +406,7 @@ type typeSlots struct { Invert *unaryOpSlot IOr *binaryOpSlot IPow *binaryOpSlot + IRShift *binaryOpSlot ISub *binaryOpSlot Iter *unaryOpSlot IXor *binaryOpSlot @@ -415,12 +423,16 @@ type typeSlots struct { New *newSlot Next *unaryOpSlot NonZero *unaryOpSlot + Oct *unaryOpSlot Or *binaryOpSlot + Pos *unaryOpSlot Pow *binaryOpSlot RAdd *binaryOpSlot RAnd *binaryOpSlot RDiv *binaryOpSlot + RDivMod *binaryOpSlot Repr *unaryOpSlot + RFloorDiv *binaryOpSlot RLShift *binaryOpSlot RMod *binaryOpSlot RMul *binaryOpSlot diff --git a/runtime/str.go b/runtime/str.go index 401f7935..af191087 100644 --- a/runtime/str.go +++ b/runtime/str.go @@ -35,6 +35,8 @@ var ( strInterpolationRegexp = regexp.MustCompile(`^%([#0 +-]?)((\*|[0-9]+)?)((\.(\*|[0-9]+))?)[hlL]?([diouxXeEfFgGcrs%])`) internedStrs = map[string]*Str{} caseOffset = byte('a' - 'A') + + internedName = NewStr("__name__") ) type stripSide int @@ -175,6 +177,19 @@ func strCapitalize(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) return NewStr(string(b)).ToObject(), nil } +func strCenter(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + s, width, fill, raised := strJustDecodeArgs(f, args, "center") + if raised != nil { + return nil, raised + } + if len(s) >= width { + return NewStr(s).ToObject(), nil + } + marg := width - len(s) + left := marg/2 + (marg & width & 1) + return NewStr(pad(s, left, marg-left, fill)).ToObject(), nil +} + func strContains(f *Frame, o *Object, value *Object) (*Object, *BaseException) { if value.isInstance(UnicodeType) { decoded, raised := toStrUnsafe(o).Decode(f, EncodeDefault, EncodeStrict) @@ -190,6 +205,16 @@ func strContains(f *Frame, o *Object, value *Object) (*Object, *BaseException) { return GetBool(strings.Contains(toStrUnsafe(o).Value(), toStrUnsafe(value).Value())).ToObject(), nil } +func strCount(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "count", args, StrType, ObjectType); raised != nil { + return nil, raised + } + s := toStrUnsafe(args[0]).Value() + sep := toStrUnsafe(args[1]).Value() + cnt := strings.Count(s, sep) + return NewInt(cnt).ToObject(), nil +} + func strDecode(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { // TODO: Accept unicode for encoding and errors args. expectedTypes := []*Type{StrType, StrType, StrType} @@ -226,44 +251,9 @@ func strEq(f *Frame, v, w *Object) (*Object, *BaseException) { // strFind returns the lowest index in s where the substring sub is found such // that sub is wholly contained in s[start:end]. Return -1 on failure. func strFind(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { - var raised *BaseException - // TODO: Support for unicode substring. 
- expectedTypes := []*Type{StrType, StrType, ObjectType, ObjectType} - argc := len(args) - if argc == 2 || argc == 3 { - expectedTypes = expectedTypes[:argc] - } - if raised := checkMethodArgs(f, "find/index", args, expectedTypes...); raised != nil { - return nil, raised - } - s := toStrUnsafe(args[0]).Value() - l := len(s) - start, end := 0, l - if argc >= 3 && args[2] != None { - start, raised = IndexInt(f, args[2]) - if raised != nil { - return nil, raised - } - } - if argc == 4 && args[3] != None { - end, raised = IndexInt(f, args[3]) - if raised != nil { - return nil, raised - } - } - if start > l { - return NewInt(-1).ToObject(), nil - } - start, end = adjustIndex(start, end, l) - if start > end { - return NewInt(-1).ToObject(), nil - } - sub := toStrUnsafe(args[1]).Value() - index := strings.Index(s[start:end], sub) - if index != -1 { - index += start - } - return NewInt(index).ToObject(), nil + return strFindOrIndex(f, args, func(s, sub string) (int, *BaseException) { + return strings.Index(s, sub), nil + }) } func strGE(f *Frame, v, w *Object) (*Object, *BaseException) { @@ -307,7 +297,7 @@ func strGetNewArgs(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkMethodArgs(f, "__getnewargs__", args, StrType); raised != nil { return nil, raised } - return NewTuple(args[0]).ToObject(), nil + return NewTuple1(args[0]).ToObject(), nil } func strGT(f *Frame, v, w *Object) (*Object, *BaseException) { @@ -325,6 +315,150 @@ func strHash(f *Frame, o *Object) (*Object, *BaseException) { return h.ToObject(), nil } +func strIndex(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + return strFindOrIndex(f, args, func(s, sub string) (i int, raised *BaseException) { + i = strings.Index(s, sub) + if i == -1 { + raised = f.RaiseType(ValueErrorType, "substring not found") + } + return i, raised + }) +} + +func strIsAlNum(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "isalnum", args, StrType); raised != nil { + return nil, raised + } + s := toStrUnsafe(args[0]).Value() + if len(s) == 0 { + return False.ToObject(), nil + } + for i := range s { + if !isAlNum(s[i]) { + return False.ToObject(), nil + } + } + return True.ToObject(), nil +} + +func strIsAlpha(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "isalpha", args, StrType); raised != nil { + return nil, raised + } + s := toStrUnsafe(args[0]).Value() + if len(s) == 0 { + return False.ToObject(), nil + } + for i := range s { + if !isAlpha(s[i]) { + return False.ToObject(), nil + } + } + return True.ToObject(), nil +} + +func strIsDigit(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "isdigit", args, StrType); raised != nil { + return nil, raised + } + s := toStrUnsafe(args[0]).Value() + if len(s) == 0 { + return False.ToObject(), nil + } + for i := range s { + if !isDigit(s[i]) { + return False.ToObject(), nil + } + } + return True.ToObject(), nil +} + +func strIsLower(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "islower", args, StrType); raised != nil { + return nil, raised + } + s := toStrUnsafe(args[0]).Value() + if len(s) == 0 { + return False.ToObject(), nil + } + for i := range s { + if !isLower(s[i]) { + return False.ToObject(), nil + } + } + return True.ToObject(), nil +} + +func strIsSpace(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "isspace", args, 
StrType); raised != nil { + return nil, raised + } + s := toStrUnsafe(args[0]).Value() + if len(s) == 0 { + return False.ToObject(), nil + } + for i := range s { + if !isSpace(s[i]) { + return False.ToObject(), nil + } + } + return True.ToObject(), nil +} + +func strIsTitle(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "istitle", args, StrType); raised != nil { + return nil, raised + } + + s := toStrUnsafe(args[0]).Value() + if len(s) == 0 { + return False.ToObject(), nil + } + + if len(s) == 1 { + return GetBool(isUpper(s[0])).ToObject(), nil + } + + cased := false + previousIsCased := false + + for i := range s { + if isUpper(s[i]) { + if previousIsCased { + return False.ToObject(), nil + } + previousIsCased = true + cased = true + } else if isLower(s[i]) { + if !previousIsCased { + return False.ToObject(), nil + } + previousIsCased = true + cased = true + } else { + previousIsCased = false + } + } + + return GetBool(cased).ToObject(), nil +} + +func strIsUpper(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "isupper", args, StrType); raised != nil { + return nil, raised + } + s := toStrUnsafe(args[0]).Value() + if len(s) == 0 { + return False.ToObject(), nil + } + for i := range s { + if !isUpper(s[i]) { + return False.ToObject(), nil + } + } + return True.ToObject(), nil +} + func strJoin(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkMethodArgs(f, "join", args, StrType, ObjectType); raised != nil { return nil, raised @@ -383,6 +517,17 @@ func strLen(f *Frame, o *Object) (*Object, *BaseException) { return NewInt(len(toStrUnsafe(o).Value())).ToObject(), nil } +func strLJust(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + s, width, fill, raised := strJustDecodeArgs(f, args, "ljust") + if raised != nil { + return nil, raised + } + if len(s) >= width { + return NewStr(s).ToObject(), nil + } + return NewStr(pad(s, 0, width-len(s), fill)).ToObject(), nil +} + func strLower(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { expectedTypes := []*Type{StrType} if raised := checkMethodArgs(f, "lower", args, expectedTypes...); raised != nil { @@ -416,7 +561,7 @@ func strMod(f *Frame, v, w *Object) (*Object, *BaseException) { case w.isInstance(TupleType): return strInterpolate(f, s, toTupleUnsafe(w)) default: - return strInterpolate(f, s, NewTuple(w)) + return strInterpolate(f, s, NewTuple1(w)) } } @@ -562,6 +707,33 @@ func strRepr(_ *Frame, o *Object) (*Object, *BaseException) { return NewStr(buf.String()).ToObject(), nil } +func strRFind(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + return strFindOrIndex(f, args, func(s, sub string) (int, *BaseException) { + return strings.LastIndex(s, sub), nil + }) +} + +func strRIndex(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + return strFindOrIndex(f, args, func(s, sub string) (i int, raised *BaseException) { + i = strings.LastIndex(s, sub) + if i == -1 { + raised = f.RaiseType(ValueErrorType, "substring not found") + } + return i, raised + }) +} + +func strRJust(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + s, width, fill, raised := strJustDecodeArgs(f, args, "rjust") + if raised != nil { + return nil, raised + } + if len(s) >= width { + return NewStr(s).ToObject(), nil + } + return NewStr(pad(s, width-len(s), 0, fill)).ToObject(), nil +} + func strSplit(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { expectedTypes := 
[]*Type{StrType, ObjectType, IntType} argc := len(args) @@ -738,19 +910,56 @@ func strStr(_ *Frame, o *Object) (*Object, *BaseException) { return NewStr(toStrUnsafe(o).Value()).ToObject(), nil } +func strSwapCase(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "swapcase", args, StrType); raised != nil { + return nil, raised + } + s := toStrUnsafe(args[0]).Value() + numBytes := len(s) + if numBytes == 0 { + return args[0], nil + } + b := make([]byte, numBytes) + for i := 0; i < numBytes; i++ { + if isLower(s[i]) { + b[i] = toUpper(s[i]) + } else if isUpper(s[i]) { + b[i] = toLower(s[i]) + } else { + b[i] = s[i] + } + } + return NewStr(string(b)).ToObject(), nil +} + func initStrType(dict map[string]*Object) { dict["__getnewargs__"] = newBuiltinFunction("__getnewargs__", strGetNewArgs).ToObject() dict["capitalize"] = newBuiltinFunction("capitalize", strCapitalize).ToObject() + dict["count"] = newBuiltinFunction("count", strCount).ToObject() + dict["center"] = newBuiltinFunction("center", strCenter).ToObject() dict["decode"] = newBuiltinFunction("decode", strDecode).ToObject() dict["endswith"] = newBuiltinFunction("endswith", strEndsWith).ToObject() dict["find"] = newBuiltinFunction("find", strFind).ToObject() + dict["index"] = newBuiltinFunction("index", strIndex).ToObject() + dict["isalnum"] = newBuiltinFunction("isalnum", strIsAlNum).ToObject() + dict["isalpha"] = newBuiltinFunction("isalpha", strIsAlpha).ToObject() + dict["isdigit"] = newBuiltinFunction("isdigit", strIsDigit).ToObject() + dict["islower"] = newBuiltinFunction("islower", strIsLower).ToObject() + dict["isspace"] = newBuiltinFunction("isspace", strIsSpace).ToObject() + dict["istitle"] = newBuiltinFunction("istitle", strIsTitle).ToObject() + dict["isupper"] = newBuiltinFunction("isupper", strIsUpper).ToObject() dict["join"] = newBuiltinFunction("join", strJoin).ToObject() dict["lower"] = newBuiltinFunction("lower", strLower).ToObject() + dict["ljust"] = newBuiltinFunction("ljust", strLJust).ToObject() dict["lstrip"] = newBuiltinFunction("lstrip", strLStrip).ToObject() + dict["rfind"] = newBuiltinFunction("rfind", strRFind).ToObject() + dict["rindex"] = newBuiltinFunction("rindex", strRIndex).ToObject() + dict["rjust"] = newBuiltinFunction("rjust", strRJust).ToObject() dict["split"] = newBuiltinFunction("split", strSplit).ToObject() dict["splitlines"] = newBuiltinFunction("splitlines", strSplitLines).ToObject() dict["startswith"] = newBuiltinFunction("startswith", strStartsWith).ToObject() dict["strip"] = newBuiltinFunction("strip", strStrip).ToObject() + dict["swapcase"] = newBuiltinFunction("swapcase", strSwapCase).ToObject() dict["replace"] = newBuiltinFunction("replace", strReplace).ToObject() dict["rstrip"] = newBuiltinFunction("rstrip", strRStrip).ToObject() dict["title"] = newBuiltinFunction("title", strTitle).ToObject() @@ -805,21 +1014,31 @@ func strInterpolate(f *Frame, format string, values *Tuple) (*Object, *BaseExcep if matches == nil { return nil, f.RaiseType(ValueErrorType, "invalid format spec") } - if matches[7] != "%" && valueIndex >= len(values.elems) { + flags, fieldType := matches[1], matches[7] + if fieldType != "%" && valueIndex >= len(values.elems) { return nil, f.RaiseType(TypeErrorType, "not enough arguments for format string") } - if matches[1] != "" { - return nil, f.RaiseType(NotImplementedErrorType, "conversion flags not yet supported") - } - if matches[2] != "" || matches[4] != "" { + fieldWidth := -1 + if matches[2] == "*" || matches[4] != "" { 
return nil, f.RaiseType(NotImplementedErrorType, "field width not yet supported") } - switch matches[7] { + if matches[2] != "" { + var err error + fieldWidth, err = strconv.Atoi(matches[2]) + if err != nil { + return nil, f.RaiseType(TypeErrorType, fmt.Sprint(err)) + } + } + if flags != "" && flags != "0" { + return nil, f.RaiseType(NotImplementedErrorType, "conversion flags not yet supported") + } + var val string + switch fieldType { case "r", "s": o := values.elems[valueIndex] var s *Str var raised *BaseException - if matches[7] == "r" { + if fieldType == "r" { s, raised = Repr(f, o) } else { s, raised = ToStr(f, o) @@ -827,46 +1046,74 @@ func strInterpolate(f *Frame, format string, values *Tuple) (*Object, *BaseExcep if raised != nil { return nil, raised } - buf.WriteString(s.Value()) + val = s.Value() + if fieldWidth > 0 { + val = strLeftPad(val, fieldWidth, " ") + } + buf.WriteString(val) valueIndex++ case "f": o := values.elems[valueIndex] - if val, ok := floatCoerce(o); ok { - buf.WriteString(strconv.FormatFloat(val, 'f', 6, 64)) + if v, ok := floatCoerce(o); ok { + val := strconv.FormatFloat(v, 'f', 6, 64) + if fieldWidth > 0 { + fillchar := " " + if flags != "" { + fillchar = flags + } + val = strLeftPad(val, fieldWidth, fillchar) + } + buf.WriteString(val) valueIndex++ } else { return nil, f.RaiseType(TypeErrorType, fmt.Sprintf("float argument required, not %s", o.typ.Name())) } - case "d", "x", "X": - var val string + case "d", "x", "X", "o": o := values.elems[valueIndex] i, raised := ToInt(f, values.elems[valueIndex]) if raised != nil { return nil, raised } - if matches[7] == "d" { + if fieldType == "d" { s, raised := ToStr(f, i) if raised != nil { return nil, raised } val = s.Value() + } else if matches[7] == "o" { + if o.isInstance(LongType) { + val = toLongUnsafe(o).Value().Text(8) + } else { + val = strconv.FormatInt(int64(toIntUnsafe(i).Value()), 8) + } } else { if o.isInstance(LongType) { val = toLongUnsafe(o).Value().Text(16) } else { val = strconv.FormatInt(int64(toIntUnsafe(i).Value()), 16) } - if matches[7] == "X" { + if fieldType == "X" { val = strings.ToUpper(val) } } + if fieldWidth > 0 { + fillchar := " " + if flags != "" { + fillchar = flags + } + val = strLeftPad(val, fieldWidth, fillchar) + } buf.WriteString(val) valueIndex++ case "%": - buf.WriteString("%") + val = "%" + if fieldWidth > 0 { + val = strLeftPad(val, fieldWidth, " ") + } + buf.WriteString(val) default: format := "conversion type not yet supported: %s" - return nil, f.RaiseType(NotImplementedErrorType, fmt.Sprintf(format, matches[7])) + return nil, f.RaiseType(NotImplementedErrorType, fmt.Sprintf(format, fieldType)) } format = format[len(matches[0]):] index = strings.Index(format, "%") @@ -996,12 +1243,12 @@ func strTitle(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { for i := 0; i < numBytes; i++ { c := s[i] switch { - case s[i] >= 'a' && s[i] <= 'z': + case isLower(c): if !previousIsCased { c = toUpper(c) } previousIsCased = true - case s[i] >= 'A' && s[i] <= 'Z': + case isUpper(c): if previousIsCased { c = toLower(c) } @@ -1036,24 +1283,11 @@ func strZFill(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { return nil, raised } s := toStrUnsafe(args[0]).Value() - l := len(s) width, raised := ToIntValue(f, args[1]) if raised != nil { return nil, raised } - if width <= l { - return args[0], nil - } - buf := bytes.Buffer{} - buf.Grow(width) - if l > 0 && (s[0] == '-' || s[0] == '+') { - buf.WriteByte(s[0]) - s = s[1:] - width-- - } - buf.WriteString(strings.Repeat("0", 
width-len(s))) - buf.WriteString(s) - return NewStr(buf.String()).ToObject(), nil + return NewStr(strLeftPad(s, width, "0")).ToObject(), nil } func init() { @@ -1064,15 +1298,152 @@ func init() { } func toLower(b byte) byte { - if b >= 'A' && b <= 'Z' { + if isUpper(b) { return b + caseOffset } return b } func toUpper(b byte) byte { - if b >= 'a' && b <= 'z' { + if isLower(b) { return b - caseOffset } return b } + +func isAlNum(c byte) bool { + return isAlpha(c) || isDigit(c) +} + +func isAlpha(c byte) bool { + return isUpper(c) || isLower(c) +} + +func isDigit(c byte) bool { + return '0' <= c && c <= '9' +} + +func isLower(c byte) bool { + return 'a' <= c && c <= 'z' +} + +func isSpace(c byte) bool { + switch c { + case ' ', '\n', '\t', '\v', '\f', '\r': + return true + default: + return false + } +} + +func isUpper(c byte) bool { + return 'A' <= c && c <= 'Z' +} + +func pad(s string, left int, right int, fillchar string) string { + buf := bytes.Buffer{} + + if left < 0 { + left = 0 + } + + if right < 0 { + right = 0 + } + + if left == 0 && right == 0 { + return s + } + + buf.Grow(left + len(s) + right) + buf.WriteString(strings.Repeat(fillchar, left)) + buf.WriteString(s) + buf.WriteString(strings.Repeat(fillchar, right)) + + return buf.String() +} + +// strLeftPad returns s padded on the left with fillchar so that its length is +// at least width. Fillchar must be a single character. When fillchar is "0", a +// leading sign in s is kept ahead of the zero padding. +func strLeftPad(s string, width int, fillchar string) string { + l := len(s) + if width <= l { + return s + } + buf := bytes.Buffer{} + buf.Grow(width) + if l > 0 && fillchar == "0" && (s[0] == '-' || s[0] == '+') { + buf.WriteByte(s[0]) + s = s[1:] + l = len(s) + width-- + } + // TODO: Support a fillchar longer than one character, or raise an error. + buf.WriteString(strings.Repeat(fillchar, width-l)) + buf.WriteString(s) + return buf.String() +} + +type indexFunc func(string, string) (int, *BaseException) + +func strFindOrIndex(f *Frame, args Args, fn indexFunc) (*Object, *BaseException) { + // TODO: Support for unicode substring. + expectedTypes := []*Type{StrType, StrType, ObjectType, ObjectType} + argc := len(args) + if argc == 2 || argc == 3 { + expectedTypes = expectedTypes[:argc] + } + if raised := checkMethodArgs(f, "find/index", args, expectedTypes...); raised != nil { + return nil, raised + } + s := toStrUnsafe(args[0]).Value() + l := len(s) + start, end := 0, l + var raised *BaseException + if argc >= 3 && args[2] != None { + start, raised = IndexInt(f, args[2]) + if raised != nil { + return nil, raised + } + } + if argc == 4 && args[3] != None { + end, raised = IndexInt(f, args[3]) + if raised != nil { + return nil, raised + } + } + // Default to an impossible search. 
+ search, sub := "", "-" + if start <= l { + start, end = adjustIndex(start, end, l) + if start <= end { + sub = toStrUnsafe(args[1]).Value() + search = s[start:end] + } + } + index, raised := fn(search, sub) + if raised != nil { + return nil, raised + } + if index != -1 { + index += start + } + return NewInt(index).ToObject(), nil +} + +func strJustDecodeArgs(f *Frame, args Args, name string) (string, int, string, *BaseException) { + expectedTypes := []*Type{StrType, IntType, StrType} + if raised := checkMethodArgs(f, name, args, expectedTypes...); raised != nil { + return "", 0, "", raised + } + s := toStrUnsafe(args[0]).Value() + width := toIntUnsafe(args[1]).Value() + fill := toStrUnsafe(args[2]).Value() + + if numChars := len(fill); numChars != 1 { + return s, width, fill, f.RaiseType(TypeErrorType, fmt.Sprintf("%[1]s() argument 2 must be char, not str", name)) + } + + return s, width, fill, nil +} diff --git a/runtime/str_test.go b/runtime/str_test.go index 5d164320..67c45b78 100644 --- a/runtime/str_test.go +++ b/runtime/str_test.go @@ -73,11 +73,19 @@ func TestStrBinaryOps(t *testing.T) { {args: wrapArgs(Add, "", newObject(ObjectType)), wantExc: mustCreateException(TypeErrorType, "unsupported operand type(s) for +: 'str' and 'object'")}, {args: wrapArgs(Add, None, ""), wantExc: mustCreateException(TypeErrorType, "unsupported operand type(s) for +: 'NoneType' and 'str'")}, {args: wrapArgs(Mod, "%s", 42), want: NewStr("42").ToObject()}, + {args: wrapArgs(Mod, "%3s", 42), want: NewStr(" 42").ToObject()}, + {args: wrapArgs(Mod, "%03s", 42), want: NewStr(" 42").ToObject()}, {args: wrapArgs(Mod, "%f", 3.14), want: NewStr("3.140000").ToObject()}, + {args: wrapArgs(Mod, "%10f", 3.14), want: NewStr(" 3.140000").ToObject()}, + {args: wrapArgs(Mod, "%010f", 3.14), want: NewStr("003.140000").ToObject()}, {args: wrapArgs(Mod, "abc %d", NewLong(big.NewInt(123))), want: NewStr("abc 123").ToObject()}, {args: wrapArgs(Mod, "%d", 3.14), want: NewStr("3").ToObject()}, {args: wrapArgs(Mod, "%%", NewTuple()), want: NewStr("%").ToObject()}, + {args: wrapArgs(Mod, "%3%", NewTuple()), want: NewStr(" %").ToObject()}, + {args: wrapArgs(Mod, "%03%", NewTuple()), want: NewStr(" %").ToObject()}, {args: wrapArgs(Mod, "%r", "abc"), want: NewStr("'abc'").ToObject()}, + {args: wrapArgs(Mod, "%6r", "abc"), want: NewStr(" 'abc'").ToObject()}, + {args: wrapArgs(Mod, "%06r", "abc"), want: NewStr(" 'abc'").ToObject()}, {args: wrapArgs(Mod, "%s %s", true), wantExc: mustCreateException(TypeErrorType, "not enough arguments for format string")}, {args: wrapArgs(Mod, "%Z", None), wantExc: mustCreateException(ValueErrorType, "invalid format spec")}, {args: wrapArgs(Mod, "%s", NewDict()), wantExc: mustCreateException(NotImplementedErrorType, "mappings not yet supported")}, @@ -91,6 +99,13 @@ func TestStrBinaryOps(t *testing.T) { {args: wrapArgs(Mod, "%f", None), wantExc: mustCreateException(TypeErrorType, "float argument required, not NoneType")}, {args: wrapArgs(Mod, "%s", newTestTuple(123, None)), wantExc: mustCreateException(TypeErrorType, "not all arguments converted during string formatting")}, {args: wrapArgs(Mod, "%d", newTestTuple("123")), wantExc: mustCreateException(TypeErrorType, "an integer is required")}, + {args: wrapArgs(Mod, "%o", newTestTuple(123)), want: NewStr("173").ToObject()}, + {args: wrapArgs(Mod, "%o", 8), want: NewStr("10").ToObject()}, + {args: wrapArgs(Mod, "%o", -8), want: NewStr("-10").ToObject()}, + {args: wrapArgs(Mod, "%03o", newTestTuple(123)), want: NewStr("173").ToObject()}, + {args: 
wrapArgs(Mod, "%04o", newTestTuple(123)), want: NewStr("0173").ToObject()}, + {args: wrapArgs(Mod, "%o", newTestTuple("123")), wantExc: mustCreateException(TypeErrorType, "an integer is required")}, + {args: wrapArgs(Mod, "%o", None), wantExc: mustCreateException(TypeErrorType, "an integer is required")}, {args: wrapArgs(Mul, "", 10), want: NewStr("").ToObject()}, {args: wrapArgs(Mul, "foo", -2), want: NewStr("").ToObject()}, {args: wrapArgs(Mul, "foobar", 0), want: NewStr("").ToObject()}, @@ -291,6 +306,19 @@ func TestStrMethods(t *testing.T) { {"capitalize", wrapArgs("вол"), NewStr("вол").ToObject(), nil}, {"capitalize", wrapArgs("foobar", 123), nil, mustCreateException(TypeErrorType, "'capitalize' of 'str' requires 1 arguments")}, {"capitalize", wrapArgs("ВОЛ"), NewStr("ВОЛ").ToObject(), nil}, + {"center", wrapArgs("foobar", 9, "#"), NewStr("##foobar#").ToObject(), nil}, + {"center", wrapArgs("foobar", 10, "#"), NewStr("##foobar##").ToObject(), nil}, + {"center", wrapArgs("foobar", 3, "#"), NewStr("foobar").ToObject(), nil}, + {"center", wrapArgs("foobar", -1, "#"), NewStr("foobar").ToObject(), nil}, + {"center", wrapArgs("foobar", 10, "##"), nil, mustCreateException(TypeErrorType, "center() argument 2 must be char, not str")}, + {"center", wrapArgs("foobar", 10, ""), nil, mustCreateException(TypeErrorType, "center() argument 2 must be char, not str")}, + {"count", wrapArgs("", "a"), NewInt(0).ToObject(), nil}, + {"count", wrapArgs("five", ""), NewInt(5).ToObject(), nil}, + {"count", wrapArgs("abba", "bb"), NewInt(1).ToObject(), nil}, + {"count", wrapArgs("abbba", "bb"), NewInt(1).ToObject(), nil}, + {"count", wrapArgs("abbbba", "bb"), NewInt(2).ToObject(), nil}, + {"count", wrapArgs("abcdeffdeabcb", "b"), NewInt(3).ToObject(), nil}, + {"count", wrapArgs(""), nil, mustCreateException(TypeErrorType, "'count' of 'str' requires 2 arguments")}, {"endswith", wrapArgs("", ""), True.ToObject(), nil}, {"endswith", wrapArgs("", "", 1), False.ToObject(), nil}, {"endswith", wrapArgs("foobar", "bar"), True.ToObject(), nil}, @@ -334,12 +362,89 @@ func TestStrMethods(t *testing.T) { {"find", wrapArgs("bar", "a", 0, -1), NewInt(1).ToObject(), nil}, {"find", wrapArgs("foo", newTestTuple("barfoo", "oo").ToObject()), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'tuple'")}, {"find", wrapArgs("foo", 123), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'int'")}, + {"index", wrapArgs("", ""), NewInt(0).ToObject(), nil}, + {"index", wrapArgs("", "", 1), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"index", wrapArgs("", "", -1), NewInt(0).ToObject(), nil}, + {"index", wrapArgs("", "", None, -1), NewInt(0).ToObject(), nil}, + {"index", wrapArgs("foobar", "bar"), NewInt(3).ToObject(), nil}, + {"index", wrapArgs("foobar", "bar", fooType), nil, mustCreateException(TypeErrorType, "slice indices must be integers or None or have an __index__ method")}, + {"index", wrapArgs("foobar", "bar", NewInt(MaxInt)), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"index", wrapArgs("foobar", "bar", None, NewInt(MaxInt)), NewInt(3).ToObject(), nil}, + {"index", wrapArgs("foobar", "bar", newObject(intIndexType)), NewInt(3).ToObject(), nil}, + {"index", wrapArgs("foobar", "bar", None, newObject(intIndexType)), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"index", wrapArgs("foobar", "bar", newObject(longIndexType)), NewInt(3).ToObject(), nil}, + {"index", 
wrapArgs("foobar", "bar", None, newObject(longIndexType)), nil, mustCreateException(ValueErrorType, "substring not found")}, + //TODO: Support unicode substring. + {"index", wrapArgs("foobar", NewUnicode("bar")), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'unicode'")}, + {"index", wrapArgs("foobar", "bar", "baz"), nil, mustCreateException(TypeErrorType, "slice indices must be integers or None or have an __index__ method")}, + {"index", wrapArgs("foobar", "bar", 0, "baz"), nil, mustCreateException(TypeErrorType, "slice indices must be integers or None or have an __index__ method")}, + {"index", wrapArgs("foobar", "bar", None), NewInt(3).ToObject(), nil}, + {"index", wrapArgs("foobar", "bar", 0, None), NewInt(3).ToObject(), nil}, + {"index", wrapArgs("foobar", "bar", 0, -2), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"index", wrapArgs("foobar", "foo", 0, 3), NewInt(0).ToObject(), nil}, + {"index", wrapArgs("foobar", "foo", 10), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"index", wrapArgs("foobar", "foo", 3, 3), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"index", wrapArgs("foobar", "bar", 3, 5), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"index", wrapArgs("foobar", "bar", 5, 3), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"index", wrapArgs("bar", "foobar"), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"index", wrapArgs("bar", "a", 1, 10), NewInt(1).ToObject(), nil}, + {"index", wrapArgs("bar", "a", NewLong(big.NewInt(1)), 10), NewInt(1).ToObject(), nil}, + {"index", wrapArgs("bar", "a", 0, NewLong(big.NewInt(2))), NewInt(1).ToObject(), nil}, + {"index", wrapArgs("bar", "a", 1, 3), NewInt(1).ToObject(), nil}, + {"index", wrapArgs("bar", "a", 0, -1), NewInt(1).ToObject(), nil}, + {"index", wrapArgs("foo", newTestTuple("barfoo", "oo").ToObject()), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'tuple'")}, + {"index", wrapArgs("foo", 123), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'int'")}, + {"index", wrapArgs("barbaz", "ba"), NewInt(0).ToObject(), nil}, + {"index", wrapArgs("barbaz", "ba", 1), NewInt(3).ToObject(), nil}, + {"isalnum", wrapArgs("123abc"), True.ToObject(), nil}, + {"isalnum", wrapArgs(""), False.ToObject(), nil}, + {"isalnum", wrapArgs("#$%"), False.ToObject(), nil}, + {"isalnum", wrapArgs("abc#123"), False.ToObject(), nil}, + {"isalnum", wrapArgs("123abc", "efg"), nil, mustCreateException(TypeErrorType, "'isalnum' of 'str' requires 1 arguments")}, + {"isalpha", wrapArgs("xyz"), True.ToObject(), nil}, + {"isalpha", wrapArgs(""), False.ToObject(), nil}, + {"isalpha", wrapArgs("#$%"), False.ToObject(), nil}, + {"isalpha", wrapArgs("abc#123"), False.ToObject(), nil}, + {"isalpha", wrapArgs("absd", "efg"), nil, mustCreateException(TypeErrorType, "'isalpha' of 'str' requires 1 arguments")}, + {"isdigit", wrapArgs("abc"), False.ToObject(), nil}, + {"isdigit", wrapArgs("123"), True.ToObject(), nil}, + {"isdigit", wrapArgs(""), False.ToObject(), nil}, + {"isdigit", wrapArgs("abc#123"), False.ToObject(), nil}, + {"isdigit", wrapArgs("123", "456"), nil, mustCreateException(TypeErrorType, "'isdigit' of 'str' requires 1 arguments")}, + {"islower", wrapArgs("abc"), True.ToObject(), nil}, + {"islower", wrapArgs("ABC"), False.ToObject(), nil}, + {"islower", wrapArgs(""), False.ToObject(), nil}, + 
{"islower", wrapArgs("abc#123"), False.ToObject(), nil}, + {"islower", wrapArgs("123", "456"), nil, mustCreateException(TypeErrorType, "'islower' of 'str' requires 1 arguments")}, + {"isupper", wrapArgs("abc"), False.ToObject(), nil}, + {"isupper", wrapArgs("ABC"), True.ToObject(), nil}, + {"isupper", wrapArgs(""), False.ToObject(), nil}, + {"isupper", wrapArgs("abc#123"), False.ToObject(), nil}, + {"isupper", wrapArgs("123", "456"), nil, mustCreateException(TypeErrorType, "'isupper' of 'str' requires 1 arguments")}, + {"isspace", wrapArgs(""), False.ToObject(), nil}, + {"isspace", wrapArgs(" "), True.ToObject(), nil}, + {"isspace", wrapArgs("\n\t\v\f\r "), True.ToObject(), nil}, + {"isspace", wrapArgs(""), False.ToObject(), nil}, + {"isspace", wrapArgs("asdad"), False.ToObject(), nil}, + {"isspace", wrapArgs(" "), True.ToObject(), nil}, + {"isspace", wrapArgs(" ", "456"), nil, mustCreateException(TypeErrorType, "'isspace' of 'str' requires 1 arguments")}, + {"istitle", wrapArgs("abc"), False.ToObject(), nil}, + {"istitle", wrapArgs("Abc&D"), True.ToObject(), nil}, + {"istitle", wrapArgs("ABc&D"), False.ToObject(), nil}, + {"istitle", wrapArgs(""), False.ToObject(), nil}, + {"istitle", wrapArgs("abc#123"), False.ToObject(), nil}, + {"istitle", wrapArgs("ABc&D", "456"), nil, mustCreateException(TypeErrorType, "'istitle' of 'str' requires 1 arguments")}, {"join", wrapArgs(",", newTestList("foo", "bar")), NewStr("foo,bar").ToObject(), nil}, {"join", wrapArgs(":", newTestList("foo", "bar", NewUnicode("baz"))), NewUnicode("foo:bar:baz").ToObject(), nil}, {"join", wrapArgs("nope", NewTuple()), NewStr("").ToObject(), nil}, {"join", wrapArgs("nope", newTestTuple("foo")), NewStr("foo").ToObject(), nil}, {"join", wrapArgs(",", newTestList("foo", "bar", 3.14)), nil, mustCreateException(TypeErrorType, "sequence item 2: expected string, float found")}, {"join", wrapArgs("\xff", newTestList(NewUnicode("foo"), NewUnicode("bar"))), nil, mustCreateException(UnicodeDecodeErrorType, "'utf8' codec can't decode byte 0xff in position 0")}, + {"ljust", wrapArgs("foobar", 10, "#"), NewStr("foobar####").ToObject(), nil}, + {"ljust", wrapArgs("foobar", 3, "#"), NewStr("foobar").ToObject(), nil}, + {"ljust", wrapArgs("foobar", -1, "#"), NewStr("foobar").ToObject(), nil}, + {"ljust", wrapArgs("foobar", 10, "##"), nil, mustCreateException(TypeErrorType, "ljust() argument 2 must be char, not str")}, + {"ljust", wrapArgs("foobar", 10, ""), nil, mustCreateException(TypeErrorType, "ljust() argument 2 must be char, not str")}, {"lower", wrapArgs(""), NewStr("").ToObject(), nil}, {"lower", wrapArgs("a"), NewStr("a").ToObject(), nil}, {"lower", wrapArgs("A"), NewStr("a").ToObject(), nil}, @@ -362,6 +467,79 @@ func TestStrMethods(t *testing.T) { {"lstrip", wrapArgs("foo", "bar", "baz"), nil, mustCreateException(TypeErrorType, "'strip' of 'str' requires 2 arguments")}, {"lstrip", wrapArgs("\xfboo", NewUnicode("o")), nil, mustCreateException(UnicodeDecodeErrorType, "'utf8' codec can't decode byte 0xfb in position 0")}, {"lstrip", wrapArgs("foo", NewUnicode("o")), NewUnicode("f").ToObject(), nil}, + {"rfind", wrapArgs("", ""), NewInt(0).ToObject(), nil}, + {"rfind", wrapArgs("", "", 1), NewInt(-1).ToObject(), nil}, + {"rfind", wrapArgs("", "", -1), NewInt(0).ToObject(), nil}, + {"rfind", wrapArgs("", "", None, -1), NewInt(0).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar"), NewInt(3).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar", fooType), nil, mustCreateException(TypeErrorType, "slice indices must be integers or 
None or have an __index__ method")}, + {"rfind", wrapArgs("foobar", "bar", NewInt(MaxInt)), NewInt(-1).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar", None, NewInt(MaxInt)), NewInt(3).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar", newObject(intIndexType)), NewInt(3).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar", None, newObject(intIndexType)), NewInt(-1).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar", newObject(longIndexType)), NewInt(3).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar", None, newObject(longIndexType)), NewInt(-1).ToObject(), nil}, + //r TODO: Support unicode substring. + {"rfind", wrapArgs("foobar", NewUnicode("bar")), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'unicode'")}, + {"rfind", wrapArgs("foobar", "bar", "baz"), nil, mustCreateException(TypeErrorType, "slice indices must be integers or None or have an __index__ method")}, + {"rfind", wrapArgs("foobar", "bar", 0, "baz"), nil, mustCreateException(TypeErrorType, "slice indices must be integers or None or have an __index__ method")}, + {"rfind", wrapArgs("foobar", "bar", None), NewInt(3).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar", 0, None), NewInt(3).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar", 0, -2), NewInt(-1).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "foo", 0, 3), NewInt(0).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "foo", 10), NewInt(-1).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "foo", 3, 3), NewInt(-1).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar", 3, 5), NewInt(-1).ToObject(), nil}, + {"rfind", wrapArgs("foobar", "bar", 5, 3), NewInt(-1).ToObject(), nil}, + {"rfind", wrapArgs("bar", "foobar"), NewInt(-1).ToObject(), nil}, + {"rfind", wrapArgs("bar", "a", 1, 10), NewInt(1).ToObject(), nil}, + {"rfind", wrapArgs("bar", "a", NewLong(big.NewInt(1)), 10), NewInt(1).ToObject(), nil}, + {"rfind", wrapArgs("bar", "a", 0, NewLong(big.NewInt(2))), NewInt(1).ToObject(), nil}, + {"rfind", wrapArgs("bar", "a", 1, 3), NewInt(1).ToObject(), nil}, + {"rfind", wrapArgs("bar", "a", 0, -1), NewInt(1).ToObject(), nil}, + {"rfind", wrapArgs("foo", newTestTuple("barfoo", "oo").ToObject()), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'tuple'")}, + {"rfind", wrapArgs("foo", 123), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'int'")}, + {"rfind", wrapArgs("barbaz", "ba"), NewInt(3).ToObject(), nil}, + {"rfind", wrapArgs("barbaz", "ba", None, 4), NewInt(0).ToObject(), nil}, + {"rindex", wrapArgs("", ""), NewInt(0).ToObject(), nil}, + {"rindex", wrapArgs("", "", 1), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"rindex", wrapArgs("", "", -1), NewInt(0).ToObject(), nil}, + {"rindex", wrapArgs("", "", None, -1), NewInt(0).ToObject(), nil}, + {"rindex", wrapArgs("foobar", "bar"), NewInt(3).ToObject(), nil}, + {"rindex", wrapArgs("foobar", "bar", fooType), nil, mustCreateException(TypeErrorType, "slice indices must be integers or None or have an __index__ method")}, + {"rindex", wrapArgs("foobar", "bar", NewInt(MaxInt)), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"rindex", wrapArgs("foobar", "bar", None, NewInt(MaxInt)), NewInt(3).ToObject(), nil}, + {"rindex", wrapArgs("foobar", "bar", newObject(intIndexType)), NewInt(3).ToObject(), nil}, + {"rindex", wrapArgs("foobar", "bar", None, newObject(intIndexType)), nil, mustCreateException(ValueErrorType, "substring 
not found")}, + {"rindex", wrapArgs("foobar", "bar", newObject(longIndexType)), NewInt(3).ToObject(), nil}, + {"rindex", wrapArgs("foobar", "bar", None, newObject(longIndexType)), nil, mustCreateException(ValueErrorType, "substring not found")}, + // TODO: Support unicode substring. + {"rindex", wrapArgs("foobar", NewUnicode("bar")), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'unicode'")}, + {"rindex", wrapArgs("foobar", "bar", "baz"), nil, mustCreateException(TypeErrorType, "slice indices must be integers or None or have an __index__ method")}, + {"rindex", wrapArgs("foobar", "bar", 0, "baz"), nil, mustCreateException(TypeErrorType, "slice indices must be integers or None or have an __index__ method")}, + {"rindex", wrapArgs("foobar", "bar", None), NewInt(3).ToObject(), nil}, + {"rindex", wrapArgs("foobar", "bar", 0, None), NewInt(3).ToObject(), nil}, + {"rindex", wrapArgs("foobar", "bar", 0, -2), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"rindex", wrapArgs("foobar", "foo", 0, 3), NewInt(0).ToObject(), nil}, + {"rindex", wrapArgs("foobar", "foo", 10), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"rindex", wrapArgs("foobar", "foo", 3, 3), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"rindex", wrapArgs("foobar", "bar", 3, 5), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"rindex", wrapArgs("foobar", "bar", 5, 3), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"rindex", wrapArgs("bar", "foobar"), nil, mustCreateException(ValueErrorType, "substring not found")}, + {"rindex", wrapArgs("bar", "a", 1, 10), NewInt(1).ToObject(), nil}, + {"rindex", wrapArgs("bar", "a", NewLong(big.NewInt(1)), 10), NewInt(1).ToObject(), nil}, + {"rindex", wrapArgs("bar", "a", 0, NewLong(big.NewInt(2))), NewInt(1).ToObject(), nil}, + {"rindex", wrapArgs("bar", "a", 1, 3), NewInt(1).ToObject(), nil}, + {"rindex", wrapArgs("bar", "a", 0, -1), NewInt(1).ToObject(), nil}, + {"rindex", wrapArgs("foo", newTestTuple("barfoo", "oo").ToObject()), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'tuple'")}, + {"rindex", wrapArgs("foo", 123), nil, mustCreateException(TypeErrorType, "'find/index' requires a 'str' object but received a 'int'")}, + {"rindex", wrapArgs("barbaz", "ba"), NewInt(3).ToObject(), nil}, + {"rindex", wrapArgs("barbaz", "ba", None, 4), NewInt(0).ToObject(), nil}, + {"rjust", wrapArgs("foobar", 10, "#"), NewStr("####foobar").ToObject(), nil}, + {"rjust", wrapArgs("foobar", 3, "#"), NewStr("foobar").ToObject(), nil}, + {"rjust", wrapArgs("foobar", -1, "#"), NewStr("foobar").ToObject(), nil}, + {"rjust", wrapArgs("foobar", 10, "##"), nil, mustCreateException(TypeErrorType, "rjust() argument 2 must be char, not str")}, + {"rjust", wrapArgs("foobar", 10, ""), nil, mustCreateException(TypeErrorType, "rjust() argument 2 must be char, not str")}, {"split", wrapArgs("foo,bar", ","), newTestList("foo", "bar").ToObject(), nil}, {"split", wrapArgs("1,2,3", ",", 1), newTestList("1", "2,3").ToObject(), nil}, {"split", wrapArgs("a \tb\nc"), newTestList("a", "b", "c").ToObject(), nil}, @@ -499,6 +677,17 @@ func TestStrMethods(t *testing.T) { {"zfill", wrapArgs("", False), NewStr("").ToObject(), nil}, {"zfill", wrapArgs("34", NewStr("test")), nil, mustCreateException(TypeErrorType, "an integer is required")}, {"zfill", wrapArgs("34"), nil, mustCreateException(TypeErrorType, "'zfill' of 'str' requires 2 
arguments")}, + {"swapcase", wrapArgs(""), NewStr("").ToObject(), nil}, + {"swapcase", wrapArgs("a"), NewStr("A").ToObject(), nil}, + {"swapcase", wrapArgs("A"), NewStr("a").ToObject(), nil}, + {"swapcase", wrapArgs(" A"), NewStr(" a").ToObject(), nil}, + {"swapcase", wrapArgs("abc"), NewStr("ABC").ToObject(), nil}, + {"swapcase", wrapArgs("ABC"), NewStr("abc").ToObject(), nil}, + {"swapcase", wrapArgs("aBC"), NewStr("Abc").ToObject(), nil}, + {"swapcase", wrapArgs("abc def", 123), nil, mustCreateException(TypeErrorType, "'swapcase' of 'str' requires 1 arguments")}, + {"swapcase", wrapArgs(123), nil, mustCreateException(TypeErrorType, "unbound method swapcase() must be called with str instance as first argument (got int instance instead)")}, + {"swapcase", wrapArgs("вол"), NewStr("вол").ToObject(), nil}, + {"swapcase", wrapArgs("ВОЛ"), NewStr("ВОЛ").ToObject(), nil}, } for _, cas := range cases { testCase := invokeTestCase{args: cas.args, want: cas.want, wantExc: cas.wantExc} diff --git a/runtime/super.go b/runtime/super.go index f8b27509..9509d00f 100644 --- a/runtime/super.go +++ b/runtime/super.go @@ -74,7 +74,7 @@ func superGetAttribute(f *Frame, o *Object, name *Str) (*Object, *BaseException) } // Now do normal mro lookup from the successor type. for ; i < n; i++ { - dict := mro[i].dict + dict := mro[i].Dict() res, raised := dict.GetItem(f, name.ToObject()) if raised != nil { return nil, raised diff --git a/runtime/threading.go b/runtime/threading.go index 7ee2c098..ad59815e 100644 --- a/runtime/threading.go +++ b/runtime/threading.go @@ -84,3 +84,36 @@ func (m *recursiveMutex) Unlock(f *Frame) { m.mutex.Unlock() } } + +// TryableMutex is a mutex-like object that also supports TryLock(). +type TryableMutex struct { + c chan bool +} + +// NewTryableMutex returns a new TryableMutex. +func NewTryableMutex() *TryableMutex { + m := &TryableMutex{make(chan bool, 1)} + m.Unlock() + return m +} + +// Lock blocks until the mutex is available and then acquires a lock. +func (m *TryableMutex) Lock() { + <-m.c +} + +// TryLock returns true and acquires a lock if the mutex is available, otherwise +// it returns false. +func (m *TryableMutex) TryLock() bool { + select { + case <-m.c: + return true + default: + return false + } +} + +// Unlock releases the mutex's lock. +func (m *TryableMutex) Unlock() { + m.c <- true +} diff --git a/runtime/tuple.go b/runtime/tuple.go index 79b58dba..a0ddfd7e 100644 --- a/runtime/tuple.go +++ b/runtime/tuple.go @@ -35,6 +35,105 @@ func NewTuple(elems ...*Object) *Tuple { return &Tuple{Object: Object{typ: TupleType}, elems: elems} } +// Below are direct allocation versions of small Tuples. Rather than performing +// two allocations, one for the tuple object and one for the slice holding the +// elements, we allocate both objects at the same time in one block of memory. +// This both decreases the number of allocations overall as well as increases +// memory locality for tuple data. Both of which *should* improve time to +// allocate as well as read performance. The methods below are used by the +// compiler to create fixed size tuples when the size is known ahead of time. +// +// The number of specializations below were chosen first to cover all the fixed +// size tuple allocations in the runtime (currently 5), then filled out to +// cover the whole memory size class (see golang/src/runtime/sizeclasses.go for +// the table). On a 64bit system, a tuple of length 6 occupies 96 bytes - 48 +// bytes for the tuple object and 6*8 (48) bytes of pointers. 
+// +// If methods are added or removed, then the constant MAX_DIRECT_TUPLE in +// compiler/util.py needs to be updated as well. + +// NewTuple0 returns the empty tuple. This is mostly provided for the +// convenience of the compiler. +func NewTuple0() *Tuple { return emptyTuple } + +// NewTuple1 returns a tuple of length 1 containing just elem0. +func NewTuple1(elem0 *Object) *Tuple { + t := struct { + tuple Tuple + elems [1]*Object + }{ + tuple: Tuple{Object: Object{typ: TupleType}}, + elems: [1]*Object{elem0}, + } + t.tuple.elems = t.elems[:] + return &t.tuple +} + +// NewTuple2 returns a tuple of length 2 containing just elem0 and elem1. +func NewTuple2(elem0, elem1 *Object) *Tuple { + t := struct { + tuple Tuple + elems [2]*Object + }{ + tuple: Tuple{Object: Object{typ: TupleType}}, + elems: [2]*Object{elem0, elem1}, + } + t.tuple.elems = t.elems[:] + return &t.tuple +} + +// NewTuple3 returns a tuple of length 3 containing elem0 to elem2. +func NewTuple3(elem0, elem1, elem2 *Object) *Tuple { + t := struct { + tuple Tuple + elems [3]*Object + }{ + tuple: Tuple{Object: Object{typ: TupleType}}, + elems: [3]*Object{elem0, elem1, elem2}, + } + t.tuple.elems = t.elems[:] + return &t.tuple +} + +// NewTuple4 returns a tuple of length 4 containing elem0 to elem3. +func NewTuple4(elem0, elem1, elem2, elem3 *Object) *Tuple { + t := struct { + tuple Tuple + elems [4]*Object + }{ + tuple: Tuple{Object: Object{typ: TupleType}}, + elems: [4]*Object{elem0, elem1, elem2, elem3}, + } + t.tuple.elems = t.elems[:] + return &t.tuple +} + +// NewTuple5 returns a tuple of length 5 containing elem0 to elem4. +func NewTuple5(elem0, elem1, elem2, elem3, elem4 *Object) *Tuple { + t := struct { + tuple Tuple + elems [5]*Object + }{ + tuple: Tuple{Object: Object{typ: TupleType}}, + elems: [5]*Object{elem0, elem1, elem2, elem3, elem4}, + } + t.tuple.elems = t.elems[:] + return &t.tuple +} + +// NewTuple6 returns a tuple of length 6 containing elem0 to elem5. 
+func NewTuple6(elem0, elem1, elem2, elem3, elem4, elem5 *Object) *Tuple { + t := struct { + tuple Tuple + elems [6]*Object + }{ + tuple: Tuple{Object: Object{typ: TupleType}}, + elems: [6]*Object{elem0, elem1, elem2, elem3, elem4, elem5}, + } + t.tuple.elems = t.elems[:] + return &t.tuple +} + func toTupleUnsafe(o *Object) *Tuple { return (*Tuple)(o.toPointer()) } @@ -75,6 +174,13 @@ func tupleContains(f *Frame, t, v *Object) (*Object, *BaseException) { return seqContains(f, t, v) } +func tupleCount(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { + if raised := checkMethodArgs(f, "count", args, TupleType, ObjectType); raised != nil { + return nil, raised + } + return seqCount(f, args[0], args[1]) +} + func tupleEq(f *Frame, v, w *Object) (*Object, *BaseException) { return tupleCompare(f, toTupleUnsafe(v), w, Eq) } @@ -99,7 +205,7 @@ func tupleGetNewArgs(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { if raised := checkMethodArgs(f, "__getnewargs__", args, TupleType); raised != nil { return nil, raised } - return NewTuple(args[0]).ToObject(), nil + return NewTuple1(args[0]).ToObject(), nil } func tupleGT(f *Frame, v, w *Object) (*Object, *BaseException) { @@ -181,6 +287,7 @@ func tupleRMul(f *Frame, v, w *Object) (*Object, *BaseException) { } func initTupleType(dict map[string]*Object) { + dict["count"] = newBuiltinFunction("count", tupleCount).ToObject() dict["__getnewargs__"] = newBuiltinFunction("__getnewargs__", tupleGetNewArgs).ToObject() TupleType.slots.Add = &binaryOpSlot{tupleAdd} TupleType.slots.Contains = &binaryOpSlot{tupleContains} diff --git a/runtime/tuple_test.go b/runtime/tuple_test.go index 41be2d3d..be18d42f 100644 --- a/runtime/tuple_test.go +++ b/runtime/tuple_test.go @@ -102,6 +102,19 @@ func TestTupleContains(t *testing.T) { } } +func TestTupleCount(t *testing.T) { + cases := []invokeTestCase{ + {args: wrapArgs(NewTuple(), NewInt(1)), want: NewInt(0).ToObject()}, + {args: wrapArgs(NewTuple(None, None, None), None), want: NewInt(3).ToObject()}, + {args: wrapArgs(NewTuple()), wantExc: mustCreateException(TypeErrorType, "'count' of 'tuple' requires 2 arguments")}, + } + for _, cas := range cases { + if err := runInvokeMethodTestCase(TupleType, "count", &cas); err != "" { + t.Error(err) + } + } +} + func BenchmarkTupleContains(b *testing.B) { b.Run("false-3", func(b *testing.B) { t := newTestTuple("foo", 42, "bar").ToObject() diff --git a/runtime/type.go b/runtime/type.go index e1071068..b5ac5a15 100644 --- a/runtime/type.go +++ b/runtime/type.go @@ -153,7 +153,11 @@ func prepareBuiltinType(typ *Type, init builtinTypeInit) { for i := 0; i < numFields; i++ { field := basis.Field(i) if attr := field.Tag.Get("attr"); attr != "" { - dict[attr] = makeStructFieldDescriptor(typ, field.Name, attr) + fieldMode := fieldDescriptorRO + if mode := field.Tag.Get("attr_mode"); mode == "rw" { + fieldMode = fieldDescriptorRW + } + dict[attr] = makeStructFieldDescriptor(typ, field.Name, attr, fieldMode) } } } @@ -168,7 +172,7 @@ func prepareBuiltinType(typ *Type, init builtinTypeInit) { } } } - typ.dict = newStringDict(dict) + typ.setDict(newStringDict(dict)) if err := prepareType(typ); err != "" { logFatal(err) } @@ -283,7 +287,7 @@ func (t *Type) Name() string { // FullName returns t's fully qualified name including the module. 
func (t *Type) FullName(f *Frame) (string, *BaseException) { - moduleAttr, raised := t.dict.GetItemString(f, "__module__") + moduleAttr, raised := t.Dict().GetItemString(f, "__module__") if raised != nil { return "", raised } @@ -309,7 +313,7 @@ func (t *Type) isSubclass(super *Type) bool { func (t *Type) mroLookup(f *Frame, name *Str) (*Object, *BaseException) { for _, t := range t.mro { - v, raised := t.dict.GetItem(f, name.ToObject()) + v, raised := t.Dict().GetItem(f, name.ToObject()) if v != nil || raised != nil { return v, raised } diff --git a/runtime/type_test.go b/runtime/type_test.go index ad968ee6..8eeedf9c 100644 --- a/runtime/type_test.go +++ b/runtime/type_test.go @@ -26,7 +26,7 @@ func TestNewClass(t *testing.T) { strBasisStructFunc := func(o *Object) *strBasisStruct { return (*strBasisStruct)(o.toPointer()) } fooType := newBasisType("Foo", reflect.TypeOf(strBasisStruct{}), strBasisStructFunc, StrType) defer delete(basisTypes, fooType.basis) - fooType.dict = NewDict() + fooType.setDict(NewDict()) prepareType(fooType) cases := []struct { wantBasis reflect.Type @@ -66,7 +66,7 @@ func TestNewBasisType(t *testing.T) { if typ.Type() != TypeType { t.Errorf("got %q, want a type", typ.Type().Name()) } - if typ.dict != nil { + if typ.Dict() != nil { t.Error("type's dict was expected to be nil") } wantBases := []*Type{ObjectType} @@ -151,7 +151,7 @@ func TestPrepareType(t *testing.T) { for _, cas := range cases { typ := newBasisType("Foo", cas.basis, cas.basisFunc, cas.base) defer delete(basisTypes, cas.basis) - typ.dict = NewDict() + typ.setDict(NewDict()) prepareType(typ) cas.wantMro[0] = typ if !reflect.DeepEqual(typ.mro, cas.wantMro) { @@ -334,7 +334,7 @@ func TestTypeGetAttribute(t *testing.T) { // __metaclass__ = BarMeta // bar = Bar() barType := &Type{Object: Object{typ: barMetaType}, name: "Bar", basis: fooType.basis, bases: []*Type{fooType}} - barType.dict = newTestDict("bar", "Bar's bar", "foo", 101, "barsetter", setter, "barmetasetter", "NOT setter") + barType.setDict(newTestDict("bar", "Bar's bar", "foo", 101, "barsetter", setter, "barmetasetter", "NOT setter")) bar := newObject(barType) prepareType(barType) cases := []invokeTestCase{ @@ -359,7 +359,7 @@ func TestTypeGetAttribute(t *testing.T) { func TestTypeName(t *testing.T) { fooType := newTestClass("Foo", []*Type{ObjectType}, NewDict()) fun := wrapFuncForTest(func(f *Frame, t *Type) (*Object, *BaseException) { - return GetAttr(f, t.ToObject(), NewStr("__name__"), nil) + return GetAttr(f, t.ToObject(), internedName, nil) }) cas := invokeTestCase{args: wrapArgs(fooType), want: NewStr("Foo").ToObject()} if err := runInvokeTestCase(fun, &cas); err != "" { diff --git a/runtime/unicode.go b/runtime/unicode.go index 596f29b0..7cf1e016 100644 --- a/runtime/unicode.go +++ b/runtime/unicode.go @@ -190,7 +190,7 @@ func unicodeGetNewArgs(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) if raised := checkMethodArgs(f, "__getnewargs__", args, UnicodeType); raised != nil { return nil, raised } - return NewTuple(args[0]).ToObject(), nil + return NewTuple1(args[0]).ToObject(), nil } func unicodeGT(f *Frame, v, w *Object) (*Object, *BaseException) { diff --git a/runtime/weakref.go b/runtime/weakref.go index 16045be4..fc8f3e11 100644 --- a/runtime/weakref.go +++ b/runtime/weakref.go @@ -16,7 +16,6 @@ package grumpy import ( "fmt" - "os" "reflect" "runtime" "sync" @@ -190,11 +189,7 @@ func weakRefFinalizeReferent(o *Object) { for i := numCallbacks - 1; i >= 0; i-- { f := NewRootFrame() if _, raised := callbacks[i].Call(f, 
Args{r.ToObject()}, nil); raised != nil { - s, raised := FormatException(f, raised) - if raised != nil { - s = raised.String() - } - fmt.Fprint(os.Stderr, s) + Stderr.writeString(FormatExc(f)) } } } diff --git a/testing/builtin_test.py b/testing/builtin_test.py index 91f85ac1..c7932682 100644 --- a/testing/builtin_test.py +++ b/testing/builtin_test.py @@ -28,6 +28,16 @@ assert abs(-3.4) == 3.4 assert isinstance(abs(-3.4), float) +assert abs(complex(0, 0)) == 0.0 +assert abs(complex(3, 4)) == 5.0 +assert abs(-complex(3, 4)) == 5.0 +assert abs(complex(0.123456e-3, 0)) == 0.000123456 +assert abs(complex(0.123456e-3, 3.14151692e+7)) == 31415169.2 +assert isinstance(abs(complex(3, 4)), float) +assert repr(abs(complex(-float('inf'), 1.2))) == 'inf' +assert repr(abs(complex(float('nan'), float('inf')))) == 'inf' +assert repr(abs(complex(3.14, float('nan')))) == 'nan' + try: abs('a') except TypeError as e: @@ -348,3 +358,49 @@ class Foo(object): assert map(None, a) == a assert map(None, a) is not a assert map(None, (1, 2, 3)) == [1, 2, 3] + +# divmod(v, w) + +import sys + +assert divmod(12, 7) == (1, 5) +assert divmod(-12, 7) == (-2, 2) +assert divmod(12, -7) == (-2, -2) +assert divmod(-12, -7) == (1, -5) +assert divmod(-sys.maxsize - 1, -1) == (sys.maxsize + 1, 0) +assert isinstance(divmod(12, 7), tuple) +assert isinstance(divmod(12, 7)[0], int) +assert isinstance(divmod(12, 7)[1], int) + +assert divmod(long(7), long(3)) == (2L, 1L) +assert divmod(long(3), long(-7)) == (-1L, -4L) +assert divmod(long(sys.maxsize), long(-sys.maxsize)) == (-1L, 0L) +assert divmod(long(-sys.maxsize), long(1)) == (-sys.maxsize, 0L) +assert divmod(long(-sys.maxsize), long(-1)) == (sys.maxsize, 0L) +assert isinstance(divmod(long(7), long(3)), tuple) +assert isinstance(divmod(long(7), long(3))[0], long) +assert isinstance(divmod(long(7), long(3))[1], long) + +assert divmod(3.25, 1.0) == (3.0, 0.25) +assert divmod(-3.25, 1.0) == (-4.0, 0.75) +assert divmod(3.25, -1.0) == (-4.0, -0.75) +assert divmod(-3.25, -1.0) == (3.0, -0.25) +assert isinstance(divmod(3.25, 1.0), tuple) +assert isinstance(divmod(3.25, 1.0)[0], float) +assert isinstance(divmod(3.25, 1.0)[1], float) + +try: + divmod('a', 'b') +except TypeError as e: + assert str(e) == "unsupported operand type(s) for divmod(): 'str' and 'str'" +else: + assert AssertionError + +# Check for a bug where zip() and map() were not properly cleaning their +# internal exception state. See: +# https://github.com/google/grumpy/issues/305 +sys.exc_clear() +zip((1, 3), (2, 4)) +assert not any(sys.exc_info()) +map(int, (1, 2, 3)) +assert not any(sys.exc_info()) diff --git a/testing/complex_test.py b/testing/complex_test.py new file mode 100644 index 00000000..715a47c8 --- /dev/null +++ b/testing/complex_test.py @@ -0,0 +1,91 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +assert repr(1j) == "1j" +assert repr(complex()) == "0j" +assert repr(complex('nan-nanj')) == '(nan+nanj)' +assert repr(complex('-Nan+NaNj')) == '(nan+nanj)' +assert repr(complex('inf-infj')) == '(inf-infj)' +assert repr(complex('+inf+infj')) == '(inf+infj)' +assert repr(complex('-infINIty+infinityj')) == '(-inf+infj)' + +assert complex(1.8456e3) == (1845.6+0j) +assert complex('1.8456e3') == (1845.6+0j) +assert complex(0, -365.12) == -365.12j +assert complex('-365.12j') == -365.12j +assert complex(-1.23E2, -45.678e1) == (-123-456.78j) +assert complex('-1.23e2-45.678e1j') == (-123-456.78j) +assert complex(21.98, -1) == (21.98-1j) +assert complex('21.98-j') == (21.98-1j) +assert complex('-j') == -1j +assert complex('+j') == 1j +assert complex('j') == 1j +assert complex(' \t \n \r ( \t \n \r 2.1-3.4j \t \n \r ) \t \n \r ') == (2.1-3.4j) +assert complex(complex(complex(3.14))) == (3.14+0j) +assert complex(complex(1, -2), .151692) == (1-1.848308j) +assert complex(complex(3.14), complex(-0.151692)) == (3.14-0.151692j) +assert complex(complex(-1, 2), complex(3, -4)) == (3+5j) + +try: + complex('((2.1-3.4j))') +except ValueError as e: + assert str(e) == "complex() arg is a malformed string" +else: + raise AssertionError('this was supposed to raise an exception') + +try: + complex('3.14 - 15.16 j') +except ValueError as e: + assert str(e) == "complex() arg is a malformed string" +else: + raise AssertionError('this was supposed to raise an exception') + +try: + complex('foo') +except ValueError as e: + assert str(e) == "complex() arg is a malformed string" +else: + raise AssertionError('this was supposed to raise an exception') + +try: + complex('foo', 1) +except TypeError as e: + assert str(e) == "complex() can't take second arg if first is a string" +else: + raise AssertionError('this was supposed to raise an exception') + +try: + complex(1, 'bar') +except TypeError as e: + assert str(e) == "complex() second arg can't be a string" +else: + raise AssertionError('this was supposed to raise an exception') + +# __nonzero__ + +assert complex(0, 0).__nonzero__() == False +assert complex(.0, .0).__nonzero__() == False +assert complex(0.0, 0.1).__nonzero__() == True +assert complex(1, 0).__nonzero__() == True +assert complex(3.14, -0.001e+5).__nonzero__() == True +assert complex(float('nan'), float('nan')).__nonzero__() == True +assert complex(-float('inf'), float('inf')).__nonzero__() == True + +# __pos__ + +assert complex(0, 0).__pos__() == 0j +assert complex(42, -0.1).__pos__() == (42-0.1j) +assert complex(-1.2, 375E+2).__pos__() == (-1.2+37500j) +assert repr(complex(5, float('nan')).__pos__()) == '(5+nanj)' +assert repr(complex(float('inf'), 0.618).__pos__()) == '(inf+0.618j)' \ No newline at end of file diff --git a/testing/file_test.py b/testing/file_test.py new file mode 100644 index 00000000..6ac7707d --- /dev/null +++ b/testing/file_test.py @@ -0,0 +1,15 @@ +f = open('/tmp/file_test__someunlikelyexistingfile', 'w') +assert f.softspace == 0 + +f.softspace = 1 +assert f.softspace == 1 + +try: + f.softspace = '4321' # should not be converted automatically +except TypeError as e: + if not str(e).endswith('is required'): + raise e # Wrong exception arrived to us! 
+else: + raise RuntimeError('a TypeError should had raised.') + +assert f.softspace == 1 diff --git a/testing/native_test.py b/testing/native_test.py index cc3e7b10..facde28b 100644 --- a/testing/native_test.py +++ b/testing/native_test.py @@ -14,11 +14,11 @@ # pylint: disable=g-multiple-import -from __go__.math import MaxInt32, Pow10, Signbit -from __go__.strings import Count, IndexAny, Repeat -from __go__.encoding.csv import NewReader as NewCSVReader -from __go__.image import Pt -from __go__.strings import NewReader as NewStringReader +from '__go__/math' import MaxInt32, Pow10, Signbit +from '__go__/strings' import Count, IndexAny, Repeat +from '__go__/encoding/csv' import NewReader as NewCSVReader +from '__go__/image' import Pt +from '__go__/strings' import NewReader as NewStringReader assert Count('foo,bar,baz', ',') == 2 assert IndexAny('foobar', 'obr') == 1 diff --git a/testing/op_test.py b/testing/op_test.py index da7aa302..0773b0ad 100644 --- a/testing/op_test.py +++ b/testing/op_test.py @@ -81,5 +81,28 @@ def TestNeg(): assert -x == -100 +def TestPos(): + x = 12 + assert +x == 12 + + x = 1.1 + assert +x == 1.1 + + x = 0.0 + assert +x == 0.0 + + x = float('inf') + assert math.isinf(+x) + + x = +float('inf') + assert math.isinf(+x) + + x = float('nan') + assert math.isnan(+x) + + x = long(100) + assert +x == 100 + + if __name__ == '__main__': weetest.RunTests() diff --git a/testing/str_test.py b/testing/str_test.py index 643163b9..92e7236f 100644 --- a/testing/str_test.py +++ b/testing/str_test.py @@ -27,6 +27,22 @@ assert "Foo".capitalize() == "Foo" assert "FOO".capitalize() == "Foo" +# Test count +assert "".count("a") == 0 +assert "abcd".count("e") == 0 +assert "abccdef".count("c") == 2 +assert "abba".count("bb") == 1 +assert "abbba".count("bb") == 1 +assert "abbbba".count("bb") == 2 +assert "five".count("") == 5 +assert ("a" * 20).count("a") == 20 + +try: + "".count() + assert AssertionError +except TypeError: + pass + # Test find assert "".find("") == 0 assert "".find("", 1) == -1 @@ -164,6 +180,72 @@ def __int__(self): assert "%x" % 0x1f == "1f" assert "%X" % 0xffff == "FFFF" +vals = [ + ['-16', '-16', ' -16', '-16', '-000000016'], + ['-10', '-10', ' -10', '-10', '-000000010'], + ['-10', '-10', ' -10', '-10', '-000000010'], + ['-16', '-16', ' -16', '-16', ' -16'], + ['-16.000000', '-16.000000', '-16.000000', '-16.000000', '-16.000000'], + ['-16', '-16', ' -16', '-16', ' -16'], + ['-20', '-20', ' -20', '-20', '-000000020'], + ['-10', '-10', ' -10', '-10', '-000000010'], + ['-a', '-a', ' -a', '-a', '-00000000a'], + ['-A', '-A', ' -A', '-A', '-00000000A'], + ['-10', '-10', ' -10', '-10', ' -10'], + ['-10.000000', '-10.000000', '-10.000000', '-10.000000', '-10.000000'], + ['-10', '-10', ' -10', '-10', ' -10'], + ['-12', '-12', ' -12', '-12', '-000000012'], + ['-1', '-1', ' -1', '-1', '-000000001'], + ['-1', '-1', ' -1', '-1', '-000000001'], + ['-1', '-1', ' -1', '-1', '-000000001'], + ['-1', '-1', ' -1', '-1', ' -1'], + ['-1.000000', '-1.000000', ' -1.000000', '-1.000000', '-01.000000'], + ['-1', '-1', ' -1', '-1', ' -1'], + ['-1', '-1', ' -1', '-1', '-000000001'], + ['0', ' 0', ' 0', '00', '0000000000'], + ['0', ' 0', ' 0', '00', '0000000000'], + ['0', ' 0', ' 0', '00', '0000000000'], + ['0', ' 0', ' 0', ' 0', ' 0'], + ['0.000000', '0.000000', ' 0.000000', '0.000000', '000.000000'], + ['0', ' 0', ' 0', ' 0', ' 0'], + ['0', ' 0', ' 0', '00', '0000000000'], + ['1', ' 1', ' 1', '01', '0000000001'], + ['1', ' 1', ' 1', '01', '0000000001'], + ['1', ' 1', ' 1', '01', '0000000001'], + 
['1', ' 1', ' 1', ' 1', ' 1'], + ['1.000000', '1.000000', ' 1.000000', '1.000000', '001.000000'], + ['1', ' 1', ' 1', ' 1', ' 1'], + ['1', ' 1', ' 1', '01', '0000000001'], + ['3', ' 3', ' 3', '03', '0000000003'], + ['3', ' 3', ' 3', '03', '0000000003'], + ['3', ' 3', ' 3', '03', '0000000003'], + ['3.14', '3.14', ' 3.14', '3.14', ' 3.14'], + ['3.140000', '3.140000', ' 3.140000', '3.140000', '003.140000'], + ['3.14', '3.14', ' 3.14', '3.14', ' 3.14'], + ['3', ' 3', ' 3', '03', '0000000003'], + ['10', '10', ' 10', '10', '0000000010'], + ['a', ' a', ' a', '0a', '000000000a'], + ['A', ' A', ' A', '0A', '000000000A'], + ['10', '10', ' 10', '10', ' 10'], + ['10.000000', '10.000000', ' 10.000000', '10.000000', '010.000000'], + ['10', '10', ' 10', '10', ' 10'], + ['12', '12', ' 12', '12', '0000000012'], + ['16', '16', ' 16', '16', '0000000016'], + ['10', '10', ' 10', '10', '0000000010'], + ['10', '10', ' 10', '10', '0000000010'], + ['16', '16', ' 16', '16', ' 16'], + ['16.000000', '16.000000', ' 16.000000', '16.000000', '016.000000'], + ['16', '16', ' 16', '16', ' 16'], + ['20', '20', ' 20', '20', '0000000020'], +] + +i = 0 +for a in [-16, -10, -1, 0, 1, 3.14, 10, 16]: + for b in "dxXrfso": + assert [("%" + b) % (a, ), ("%2" + b) % (a, ), ("%10" + b) % (a, ), + ("%02" + b) % (a, ), ("%010" + b) % (a, )] == vals[i] + i += 1 + # Test replace assert 'one!two!three!'.replace('!', '@', 1) == 'one@two!three!' assert 'one!two!three!'.replace('!', '') == 'onetwothree' @@ -283,3 +365,7 @@ def __int__(self): assert '3'.zfill(A()) == '003' assert '3'.zfill(IntIntType()) == '03' assert '3'.zfill(LongIntType()) == '03' + +assert '%o' % 8 == '10' +assert '%o' % -8 == '-10' +assert '%o %o' % (8, -8) == '10 -10' diff --git a/testing/try_test.py b/testing/try_test.py index 4df6f211..e0556bcc 100644 --- a/testing/try_test.py +++ b/testing/try_test.py @@ -220,3 +220,31 @@ def f(): foo() + + +# Return statement should not bypass the finally. +def foo(): + try: + return 1 + finally: + return 2 + return 3 + + +assert foo() == 2 + + +# Break statement should not bypass finally. +x = [] +def foo(): + while True: + try: + x.append(1) + break + finally: + x.append(2) + x.append(3) + + +foo() +assert x == [1, 2, 3] diff --git a/testing/md5_test.py b/testing/tuple_test.py similarity index 73% rename from testing/md5_test.py rename to testing/tuple_test.py index 11e29cd7..bda282ae 100644 --- a/testing/md5_test.py +++ b/testing/tuple_test.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import md5 +# Test count +assert ().count(0) == 0 +assert (1, 2, 3).count(2) == 1 +assert ("a", "b", "a", "a").count("a") == 3 +assert ((2,) * 20).count(2) == 20 -assert md5.new("").hexdigest() == 'd41d8cd98f00b204e9800998ecf8427e' -assert md5.new("hello").hexdigest() == '5d41402abc4b2a76b9719d911017c592' +try: + ().count() + assert AssertionError +except TypeError: + pass diff --git a/testing/with_test.py b/testing/with_test.py index 00545308..04112681 100644 --- a/testing/with_test.py +++ b/testing/with_test.py @@ -180,3 +180,32 @@ def __exit__(self, *args): assert h == 1 assert i == 2 assert j == 3 + + +class Foo(object): + exited = False + def __enter__(self): + pass + def __exit__(self, *args): + self.exited = True + + +# This checks for a bug where a with clause inside an except body raises an +# exception because it was checking ExcInfo() to determine whether an exception +# occurred. 
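# Concretely: even though sys.exc_info() is still populated while the except
# handler below runs, the with body itself raises nothing, so Foo.__exit__
# must be called with (None, None, None) and the AssertionError being handled
# must not leak out of the handler.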
+try: + raise AssertionError +except: + foo = Foo() + with foo: + pass + assert foo.exited + + +# Return statement should not bypass the with exit handler. +foo = Foo() +def bar(): + with foo: + return +bar() +assert foo.exited diff --git a/third_party/ouroboros/AUTHORS b/third_party/ouroboros/AUTHORS new file mode 100644 index 00000000..d3d7128d --- /dev/null +++ b/third_party/ouroboros/AUTHORS @@ -0,0 +1,8 @@ +Ouroboros was originally created in Jan 2016. + +The PRIMARY AUTHORS are (and/or have been): + Russell Keith-Magee + +And here is an inevitably incomplete list of MUCH-APPRECIATED CONTRIBUTORS -- +people who have submitted patches, reported bugs, added translations, helped +answer newbie questions, and generally made Ouroboros that much better: diff --git a/third_party/ouroboros/LICENSE b/third_party/ouroboros/LICENSE new file mode 100644 index 00000000..20d5185b --- /dev/null +++ b/third_party/ouroboros/LICENSE @@ -0,0 +1,82 @@ +Copyright (c) 2016 Russell Keith-Magee. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the name of Ouroboros nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +The original Python Standard Library, upon which this project is based, is +released under the terms of the PSF License version 2: + +Copyright (c) 2001-2014 Python Software Foundation; All rights reserved + +PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +-------------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. 
Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +2011, 2012, 2013, 2014 Python Software Foundation; All Rights Reserved" are +retained in Python alone or in any derivative version prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. diff --git a/third_party/ouroboros/READEME.md b/third_party/ouroboros/READEME.md new file mode 100644 index 00000000..e8d34e7b --- /dev/null +++ b/third_party/ouroboros/READEME.md @@ -0,0 +1 @@ +The source code in this directory is forked from github.com/pybee/ouroboros. There are very light modifications to the source code so that it will work with Grumpy. diff --git a/third_party/ouroboros/operator.py b/third_party/ouroboros/operator.py new file mode 100644 index 00000000..2211bb40 --- /dev/null +++ b/third_party/ouroboros/operator.py @@ -0,0 +1,415 @@ +""" +Operator Interface +This module exports a set of functions corresponding to the intrinsic +operators of Python. For example, operator.add(x, y) is equivalent +to the expression x+y. The function names are those used for special +methods; variants without leading and trailing '__' are also provided +for convenience. +This is the pure Python implementation of the module. 
+""" + +__all__ = ['abs', 'add', 'and_', 'attrgetter', 'concat', 'contains', 'countOf', + 'delitem', 'eq', 'floordiv', 'ge', 'getitem', 'gt', 'iadd', 'iand', + 'iconcat', 'ifloordiv', 'ilshift', 'imod', 'imul', 'index', + 'indexOf', 'inv', 'invert', 'ior', 'ipow', 'irshift', 'is_', + 'is_not', 'isub', 'itemgetter', 'itruediv', 'ixor', 'le', + 'length_hint', 'lshift', 'lt', 'methodcaller', 'mod', 'mul', 'ne', + 'neg', 'not_', 'or_', 'pos', 'pow', 'rshift', 'setitem', 'sub', + 'truediv', 'truth', 'xor'] + +from '__go__/math' import Abs as _abs + + +# Comparison Operations *******************************************************# + +def lt(a, b): + "Same as a < b." + return a < b + +def le(a, b): + "Same as a <= b." + return a <= b + +def eq(a, b): + "Same as a == b." + return a == b + +def ne(a, b): + "Same as a != b." + return a != b + +def ge(a, b): + "Same as a >= b." + return a >= b + +def gt(a, b): + "Same as a > b." + return a > b + +# Logical Operations **********************************************************# + +def not_(a): + "Same as not a." + return not a + +def truth(a): + "Return True if a is true, False otherwise." + return True if a else False + +def is_(a, b): + "Same as a is b." + return a is b + +def is_not(a, b): + "Same as a is not b." + return a is not b + +# Mathematical/Bitwise Operations *********************************************# + +def abs(a): + "Same as abs(a)." + return _abs(a) + +def add(a, b): + "Same as a + b." + return a + b + +def and_(a, b): + "Same as a & b." + return a & b + +def floordiv(a, b): + "Same as a // b." + return a // b + +def index(a): + "Same as a.__index__()." + return a.__index__() + +def inv(a): + "Same as ~a." + return ~a +invert = inv + +def lshift(a, b): + "Same as a << b." + return a << b + +def mod(a, b): + "Same as a % b." + return a % b + +def mul(a, b): + "Same as a * b." + return a * b + +def neg(a): + "Same as -a." + return -a + +def or_(a, b): + "Same as a | b." + return a | b + +def pos(a): + "Same as +a." + return +a + +def pow(a, b): + "Same as a ** b." + return a**b + +def rshift(a, b): + "Same as a >> b." + return a >> b + +def sub(a, b): + "Same as a - b." + return a - b + +def truediv(a, b): + "Same as a / b." + if type(a) == int or type(a) == long: + a = float(a) + return a / b + +def xor(a, b): + "Same as a ^ b." + return a ^ b + +# Sequence Operations *********************************************************# + +def concat(a, b): + "Same as a + b, for a and b sequences." + if not hasattr(a, '__getitem__'): + msg = "'%s' object can't be concatenated" % type(a).__name__ + raise TypeError(msg) + return a + b + +def contains(a, b): + "Same as b in a (note reversed operands)." + return b in a + +def countOf(a, b): + "Return the number of times b occurs in a." + count = 0 + for i in a: + if i == b: + count += 1 + return count + +def delitem(a, b): + "Same as del a[b]." + del a[b] + +def getitem(a, b): + "Same as a[b]." + return a[b] + +def indexOf(a, b): + "Return the first index of b in a." + for i, j in enumerate(a): + if j == b: + return i + else: + raise ValueError('sequence.index(x): x not in sequence') + +def setitem(a, b, c): + "Same as a[b] = c." + a[b] = c + +def length_hint(obj, default=0): + """ + Return an estimate of the number of items in obj. + This is useful for presizing containers when building from an iterable. + If the object supports len(), the result will be exact. Otherwise, it may + over- or under-estimate by an arbitrary amount. The result will be an + integer >= 0. 
+ """ + if not isinstance(default, int): + msg = ("'%s' object cannot be interpreted as an integer" % + type(default).__name__) + raise TypeError(msg) + + try: + return len(obj) + except TypeError: + pass + + try: + hint = type(obj).__length_hint__ + except AttributeError: + return default + + try: + val = hint(obj) + except TypeError: + return default + if val is NotImplemented: + return default + if not isinstance(val, int): + msg = ('__length_hint__ must be integer, not %s' % + type(val).__name__) + raise TypeError(msg) + if val < 0: + msg = '__length_hint__() should return >= 0' + raise ValueError(msg) + return val + +# Generalized Lookup Objects **************************************************# + +# TODO: class attrgetter: +class attrgetter(object): + """ + Return a callable object that fetches the given attribute(s) from its operand. + After f = attrgetter('name'), the call f(r) returns r.name. + After g = attrgetter('name', 'date'), the call g(r) returns (r.name, r.date). + After h = attrgetter('name.first', 'name.last'), the call h(r) returns + (r.name.first, r.name.last). + """ + def __init__(self, attr, *attrs): + if not attrs: + if not isinstance(attr, str): + raise TypeError('attribute name must be a string') + names = attr.split('.') + def func(obj): + for name in names: + obj = getattr(obj, name) + return obj + self._call = func + else: + getters = tuple(map(attrgetter, (attr,) + attrs)) + def func(obj): + return tuple(getter(obj) for getter in getters) + self._call = func + + def __call__(self, obj): + return self._call(obj) + +# TODO: class itemgetter: +class itemgetter(object): + """ + Return a callable object that fetches the given item(s) from its operand. + After f = itemgetter(2), the call f(r) returns r[2]. + After g = itemgetter(2, 5, 3), the call g(r) returns (r[2], r[5], r[3]) + """ + def __init__(self, item, *items): + if not items: + def func(obj): + return obj[item] + self._call = func + else: + items = (item,) + items + def func(obj): + return tuple(obj[i] for i in items) + self._call = func + + def __call__(self, obj): + return self._call(obj) + +# TODO: class methodcaller: +class methodcaller(object): + """ + Return a callable object that calls the given method on its operand. + After f = methodcaller('name'), the call f(r) returns r.name(). + After g = methodcaller('name', 'date', foo=1), the call g(r) returns + r.name('date', foo=1). + """ + + def __init__(*args, **kwargs): + if len(args) < 2: + msg = "methodcaller needs at least one argument, the method name" + raise TypeError(msg) + self = args[0] + self._name = args[1] + self._args = args[2:] + self._kwargs = kwargs + + def __call__(self, obj): + return getattr(obj, self._name)(*self._args, **self._kwargs) + +# In-place Operations *********************************************************# + +def iadd(a, b): + "Same as a += b." + a += b + return a + +def iand(a, b): + "Same as a &= b." + a &= b + return a + +def iconcat(a, b): + "Same as a += b, for a and b sequences." + if not hasattr(a, '__getitem__'): + msg = "'%s' object can't be concatenated" % type(a).__name__ + raise TypeError(msg) + a += b + return a + +def ifloordiv(a, b): + "Same as a //= b." + a //= b + return a + +def ilshift(a, b): + "Same as a <<= b." + a <<= b + return a + +def imod(a, b): + "Same as a %= b." + a %= b + return a + +def imul(a, b): + "Same as a *= b." + a *= b + return a + +def ior(a, b): + "Same as a |= b." + a |= b + return a + +def ipow(a, b): + "Same as a **= b." 
+ a **= b + return a + +def irshift(a, b): + "Same as a >>= b." + a >>= b + return a + +def isub(a, b): + "Same as a -= b." + a -= b + return a + +def itruediv(a, b): + "Same as a /= b." + if type(a) == int or type(a) == long: + a = float(a) + a /= b + return a + +def ixor(a, b): + "Same as a ^= b." + a ^= b + return a + +# TODO: https://github.com/google/grumpy/pull/263 +#try: +# from _operator import * +#except ImportError: +# pass +#else: +# from _operator import __doc__ + +# All of these "__func__ = func" assignments have to happen after importing +# from _operator to make sure they're set to the right function +__lt__ = lt +__le__ = le +__eq__ = eq +__ne__ = ne +__ge__ = ge +__gt__ = gt +__not__ = not_ +__abs__ = abs +__add__ = add +__and__ = and_ +__floordiv__ = floordiv +__index__ = index +__inv__ = inv +__invert__ = invert +__lshift__ = lshift +__mod__ = mod +__mul__ = mul +__neg__ = neg +__or__ = or_ +__pos__ = pos +__pow__ = pow +__rshift__ = rshift +__sub__ = sub +__truediv__ = truediv +__xor__ = xor +__concat__ = concat +__contains__ = contains +__delitem__ = delitem +__getitem__ = getitem +__setitem__ = setitem +__iadd__ = iadd +__iand__ = iand +__iconcat__ = iconcat +__ifloordiv__ = ifloordiv +__ilshift__ = ilshift +__imod__ = imod +__imul__ = imul +__ior__ = ior +__ipow__ = ipow +__irshift__ = irshift +__isub__ = isub +__itruediv__ = itruediv +__ixor__ = ixor diff --git a/third_party/ouroboros/test/test_operator.py b/third_party/ouroboros/test/test_operator.py new file mode 100644 index 00000000..3ba34b88 --- /dev/null +++ b/third_party/ouroboros/test/test_operator.py @@ -0,0 +1,484 @@ +import unittest +import operator +from test import test_support + +class Seq1(object): + def __init__(self, lst): + self.lst = lst + def __len__(self): + return len(self.lst) + def __getitem__(self, i): + return self.lst[i] + def __add__(self, other): + return self.lst + other.lst + def __mul__(self, other): + return self.lst * other + def __rmul__(self, other): + return other * self.lst + +class Seq2(object): + def __init__(self, lst): + self.lst = lst + def __len__(self): + return len(self.lst) + def __getitem__(self, i): + return self.lst[i] + def __add__(self, other): + return self.lst + other.lst + def __mul__(self, other): + return self.lst * other + def __rmul__(self, other): + return other * self.lst + +class OperatorTestCase(unittest.TestCase): + def test_lt(self): + #operator = self.module + self.assertRaises(TypeError, operator.lt) + self.assertFalse(operator.lt(1, 0)) + self.assertFalse(operator.lt(1, 0.0)) + self.assertFalse(operator.lt(1, 1)) + self.assertFalse(operator.lt(1, 1.0)) + self.assertTrue(operator.lt(1, 2)) + self.assertTrue(operator.lt(1, 2.0)) + + def test_le(self): + #operator = self.module + self.assertRaises(TypeError, operator.le) + self.assertFalse(operator.le(1, 0)) + self.assertFalse(operator.le(1, 0.0)) + self.assertTrue(operator.le(1, 1)) + self.assertTrue(operator.le(1, 1.0)) + self.assertTrue(operator.le(1, 2)) + self.assertTrue(operator.le(1, 2.0)) + + def test_eq(self): + #operator = self.module + class C(object): + def __eq__(self, other): + raise SyntaxError + self.assertRaises(TypeError, operator.eq) + self.assertRaises(SyntaxError, operator.eq, C(), C()) + self.assertFalse(operator.eq(1, 0)) + self.assertFalse(operator.eq(1, 0.0)) + self.assertTrue(operator.eq(1, 1)) + self.assertTrue(operator.eq(1, 1.0)) + self.assertFalse(operator.eq(1, 2)) + self.assertFalse(operator.eq(1, 2.0)) + + def test_ne(self): + #operator = self.module + class C(object): 
+ def __ne__(self, other): + raise SyntaxError + self.assertRaises(TypeError, operator.ne) + self.assertRaises(SyntaxError, operator.ne, C(), C()) + self.assertTrue(operator.ne(1, 0)) + self.assertTrue(operator.ne(1, 0.0)) + self.assertFalse(operator.ne(1, 1)) + self.assertFalse(operator.ne(1, 1.0)) + self.assertTrue(operator.ne(1, 2)) + self.assertTrue(operator.ne(1, 2.0)) + + def test_ge(self): + #operator = self.module + self.assertRaises(TypeError, operator.ge) + self.assertTrue(operator.ge(1, 0)) + self.assertTrue(operator.ge(1, 0.0)) + self.assertTrue(operator.ge(1, 1)) + self.assertTrue(operator.ge(1, 1.0)) + self.assertFalse(operator.ge(1, 2)) + self.assertFalse(operator.ge(1, 2.0)) + + def test_gt(self): + #operator = self.module + self.assertRaises(TypeError, operator.gt) + self.assertTrue(operator.gt(1, 0)) + self.assertTrue(operator.gt(1, 0.0)) + self.assertFalse(operator.gt(1, 1)) + self.assertFalse(operator.gt(1, 1.0)) + self.assertFalse(operator.gt(1, 2)) + self.assertFalse(operator.gt(1, 2.0)) + + def test_abs(self): + #operator = self.module + self.assertRaises(TypeError, operator.abs) + self.assertRaises(TypeError, operator.abs, None) + self.assertEqual(operator.abs(-1), 1) + self.assertEqual(operator.abs(1), 1) + + def test_add(self): + #operator = self.module + self.assertRaises(TypeError, operator.add) + self.assertRaises(TypeError, operator.add, None, None) + self.assertTrue(operator.add(3, 4) == 7) + + def test_bitwise_and(self): + #operator = self.module + self.assertRaises(TypeError, operator.and_) + self.assertRaises(TypeError, operator.and_, None, None) + self.assertTrue(operator.and_(0xf, 0xa) == 0xa) + + def test_concat(self): + #operator = self.module + self.assertRaises(TypeError, operator.concat) + self.assertRaises(TypeError, operator.concat, None, None) + self.assertTrue(operator.concat('py', 'thon') == 'python') + self.assertTrue(operator.concat([1, 2], [3, 4]) == [1, 2, 3, 4]) + self.assertTrue(operator.concat(Seq1([5, 6]), Seq1([7])) == [5, 6, 7]) + self.assertTrue(operator.concat(Seq2([5, 6]), Seq2([7])) == [5, 6, 7]) + self.assertRaises(TypeError, operator.concat, 13, 29) + + def test_countOf(self): + #operator = self.module + self.assertRaises(TypeError, operator.countOf) + self.assertRaises(TypeError, operator.countOf, None, None) + self.assertTrue(operator.countOf([1, 2, 1, 3, 1, 4], 3) == 1) + self.assertTrue(operator.countOf([1, 2, 1, 3, 1, 4], 5) == 0) + + @unittest.expectedFailure + def test_delitem(self): + #operator = self.module + a = [4, 3, 2, 1] + self.assertRaises(TypeError, operator.delitem, a) + self.assertRaises(TypeError, operator.delitem, a, None) + self.assertTrue(operator.delitem(a, 1) is None) + self.assertTrue(a == [4, 2, 1]) + + def test_floordiv(self): + #operator = self.module + self.assertRaises(TypeError, operator.floordiv, 5) + self.assertRaises(TypeError, operator.floordiv, None, None) + self.assertTrue(operator.floordiv(5, 2) == 2) + + def test_truediv(self): + #operator = self.module + self.assertRaises(TypeError, operator.truediv, 5) + self.assertRaises(TypeError, operator.truediv, None, None) + self.assertTrue(operator.truediv(5, 2) == 2.5) + + def test_getitem(self): + #operator = self.module + a = range(10) + self.assertRaises(TypeError, operator.getitem) + self.assertRaises(TypeError, operator.getitem, a, None) + self.assertTrue(operator.getitem(a, 2) == 2) + + def test_indexOf(self): + #operator = self.module + self.assertRaises(TypeError, operator.indexOf) + self.assertRaises(TypeError, operator.indexOf, None, 
None) + self.assertTrue(operator.indexOf([4, 3, 2, 1], 3) == 1) + self.assertRaises(ValueError, operator.indexOf, [4, 3, 2, 1], 0) + + def test_invert(self): + #operator = self.module + self.assertRaises(TypeError, operator.invert) + self.assertRaises(TypeError, operator.invert, None) + self.assertEqual(operator.inv(4), -5) + + def test_lshift(self): + #operator = self.module + self.assertRaises(TypeError, operator.lshift) + self.assertRaises(TypeError, operator.lshift, None, 42) + self.assertTrue(operator.lshift(5, 1) == 10) + self.assertTrue(operator.lshift(5, 0) == 5) + self.assertRaises(ValueError, operator.lshift, 2, -1) + + def test_mod(self): + #operator = self.module + self.assertRaises(TypeError, operator.mod) + self.assertRaises(TypeError, operator.mod, None, 42) + self.assertTrue(operator.mod(5, 2) == 1) + + def test_mul(self): + #operator = self.module + self.assertRaises(TypeError, operator.mul) + self.assertRaises(TypeError, operator.mul, None, None) + self.assertTrue(operator.mul(5, 2) == 10) + + def test_neg(self): + #operator = self.module + self.assertRaises(TypeError, operator.neg) + self.assertRaises(TypeError, operator.neg, None) + self.assertEqual(operator.neg(5), -5) + self.assertEqual(operator.neg(-5), 5) + self.assertEqual(operator.neg(0), 0) + self.assertEqual(operator.neg(-0), 0) + + def test_bitwise_or(self): + #operator = self.module + self.assertRaises(TypeError, operator.or_) + self.assertRaises(TypeError, operator.or_, None, None) + self.assertTrue(operator.or_(0xa, 0x5) == 0xf) + + def test_pos(self): + #operator = self.module + self.assertRaises(TypeError, operator.pos) + self.assertRaises(TypeError, operator.pos, None) + self.assertEqual(operator.pos(5), 5) + self.assertEqual(operator.pos(-5), -5) + self.assertEqual(operator.pos(0), 0) + self.assertEqual(operator.pos(-0), 0) + + def test_pow(self): + #operator = self.module + self.assertRaises(TypeError, operator.pow) + self.assertRaises(TypeError, operator.pow, None, None) + self.assertEqual(operator.pow(3,5), 3**5) + self.assertRaises(TypeError, operator.pow, 1) + self.assertRaises(TypeError, operator.pow, 1, 2, 3) + + def test_rshift(self): + #operator = self.module + self.assertRaises(TypeError, operator.rshift) + self.assertRaises(TypeError, operator.rshift, None, 42) + self.assertTrue(operator.rshift(5, 1) == 2) + self.assertTrue(operator.rshift(5, 0) == 5) + self.assertRaises(ValueError, operator.rshift, 2, -1) + + def test_contains(self): + #operator = self.module + self.assertRaises(TypeError, operator.contains) + self.assertRaises(TypeError, operator.contains, None, None) + self.assertTrue(operator.contains(range(4), 2)) + self.assertFalse(operator.contains(range(4), 5)) + + def test_setitem(self): + #operator = self.module + a = list(range(3)) + self.assertRaises(TypeError, operator.setitem, a) + self.assertRaises(TypeError, operator.setitem, a, None, None) + self.assertTrue(operator.setitem(a, 0, 2) is None) + self.assertTrue(a == [2, 1, 2]) + self.assertRaises(IndexError, operator.setitem, a, 4, 2) + + def test_sub(self): + #operator = self.module + self.assertRaises(TypeError, operator.sub) + self.assertRaises(TypeError, operator.sub, None, None) + self.assertTrue(operator.sub(5, 2) == 3) + + @unittest.expectedFailure + def test_truth(self): + #operator = self.module + class C(object): + def __bool__(self): + raise SyntaxError + self.assertRaises(TypeError, operator.truth) + self.assertRaises(SyntaxError, operator.truth, C()) + self.assertTrue(operator.truth(5)) + 
self.assertTrue(operator.truth([0])) + self.assertFalse(operator.truth(0)) + self.assertFalse(operator.truth([])) + + def test_bitwise_xor(self): + #operator = self.module + self.assertRaises(TypeError, operator.xor) + self.assertRaises(TypeError, operator.xor, None, None) + self.assertTrue(operator.xor(0xb, 0xc) == 0x7) + + def test_is(self): + #operator = self.module + a = b = 'xyzpdq' + c = a[:3] + b[3:] + self.assertRaises(TypeError, operator.is_) + self.assertTrue(operator.is_(a, b)) + #self.assertFalse(operator.is_(a,c)) + + @unittest.expectedFailure + def test_is_not(self): + #operator = self.module + a = b = 'xyzpdq' + c = a[:3] + b[3:] + self.assertRaises(TypeError, operator.is_not) + self.assertFalse(operator.is_not(a, b)) + self.assertTrue(operator.is_not(a,c)) + + @unittest.expectedFailure + def test_attrgetter(self): + #operator = self.module + class A(object): + pass + a = A() + a.name = 'arthur' + f = operator.attrgetter('name') + self.assertEqual(f(a), 'arthur') + f = operator.attrgetter('rank') + self.assertRaises(AttributeError, f, a) + self.assertRaises(TypeError, operator.attrgetter, 2) + self.assertRaises(TypeError, operator.attrgetter) + + # multiple gets + record = A() + record.x = 'X' + record.y = 'Y' + record.z = 'Z' + self.assertEqual(operator.attrgetter('x','z','y')(record), ('X', 'Z', 'Y')) + self.assertRaises(TypeError, operator.attrgetter, ('x', (), 'y')) + + class C(object): + def __getattr__(self, name): + raise SyntaxError + self.assertRaises(SyntaxError, operator.attrgetter('foo'), C()) + + # recursive gets + a = A() + a.name = 'arthur' + a.child = A() + a.child.name = 'thomas' + f = operator.attrgetter('child.name') + self.assertEqual(f(a), 'thomas') + self.assertRaises(AttributeError, f, a.child) + f = operator.attrgetter('name', 'child.name') + self.assertEqual(f(a), ('arthur', 'thomas')) + f = operator.attrgetter('name', 'child.name', 'child.child.name') + self.assertRaises(AttributeError, f, a) + f = operator.attrgetter('child.') + self.assertRaises(AttributeError, f, a) + f = operator.attrgetter('.child') + self.assertRaises(AttributeError, f, a) + + a.child.child = A() + a.child.child.name = 'johnson' + f = operator.attrgetter('child.child.name') + self.assertEqual(f(a), 'johnson') + f = operator.attrgetter('name', 'child.name', 'child.child.name') + self.assertEqual(f(a), ('arthur', 'thomas', 'johnson')) + + @unittest.expectedFailure + def test_itemgetter(self): + #operator = self.module + a = 'ABCDE' + f = operator.itemgetter(2) + self.assertEqual(f(a), 'C') + f = operator.itemgetter(10) + self.assertRaises(IndexError, f, a) + + class C(object): + def __getitem__(self, name): + raise SyntaxError + self.assertRaises(SyntaxError, operator.itemgetter(42), C()) + + f = operator.itemgetter('name') + self.assertRaises(TypeError, f, a) + self.assertRaises(TypeError, operator.itemgetter) + + d = dict(key='val') + f = operator.itemgetter('key') + self.assertEqual(f(d), 'val') + f = operator.itemgetter('nonkey') + self.assertRaises(KeyError, f, d) + + # example used in the docs + inventory = [('apple', 3), ('banana', 2), ('pear', 5), ('orange', 1)] + getcount = operator.itemgetter(1) + self.assertEqual(list(map(getcount, inventory)), [3, 2, 5, 1]) + self.assertEqual(sorted(inventory, key=getcount), + [('orange', 1), ('banana', 2), ('apple', 3), ('pear', 5)]) + + # multiple gets + data = list(map(str, range(20))) + self.assertEqual(operator.itemgetter(2,10,5)(data), ('2', '10', '5')) + self.assertRaises(TypeError, operator.itemgetter(2, 'x', 5), data) + + 
def test_methodcaller(self): + #operator = self.module + self.assertRaises(TypeError, operator.methodcaller) + class A(object): + def foo(self, *args, **kwds): + return args[0] + args[1] + def bar(self, f=42): + return f + def baz(*args, **kwds): + return kwds['name'], kwds['self'] + a = A() + f = operator.methodcaller('foo') + self.assertRaises(IndexError, f, a) + f = operator.methodcaller('foo', 1, 2) + self.assertEqual(f(a), 3) + f = operator.methodcaller('bar') + self.assertEqual(f(a), 42) + self.assertRaises(TypeError, f, a, a) + f = operator.methodcaller('bar', f=5) + self.assertEqual(f(a), 5) + f = operator.methodcaller('baz', name='spam', self='eggs') + self.assertEqual(f(a), ('spam', 'eggs')) + + @unittest.expectedFailure + def test_inplace(self): + #operator = self.module + class C(object): + def __iadd__ (self, other): return "iadd" + def __iand__ (self, other): return "iand" + def __ifloordiv__(self, other): return "ifloordiv" + def __ilshift__ (self, other): return "ilshift" + def __imod__ (self, other): return "imod" + def __imul__ (self, other): return "imul" + def __ior__ (self, other): return "ior" + def __ipow__ (self, other): return "ipow" + def __irshift__ (self, other): return "irshift" + def __isub__ (self, other): return "isub" + def __itruediv__ (self, other): return "itruediv" + def __ixor__ (self, other): return "ixor" + def __getitem__(self, other): return 5 # so that C is a sequence + c = C() + self.assertEqual(operator.iadd (c, 5), "iadd") + self.assertEqual(operator.iand (c, 5), "iand") + self.assertEqual(operator.ifloordiv(c, 5), "ifloordiv") + self.assertEqual(operator.ilshift (c, 5), "ilshift") + self.assertEqual(operator.imod (c, 5), "imod") + self.assertEqual(operator.imul (c, 5), "imul") + self.assertEqual(operator.ior (c, 5), "ior") + self.assertEqual(operator.ipow (c, 5), "ipow") + self.assertEqual(operator.irshift (c, 5), "irshift") + self.assertEqual(operator.isub (c, 5), "isub") + self.assertEqual(operator.itruediv (c, 5), "itruediv") + self.assertEqual(operator.ixor (c, 5), "ixor") + self.assertEqual(operator.iconcat (c, c), "iadd") + + @unittest.expectedFailure + def test_length_hint(self): + #operator = self.module + class X(object): + def __init__(self, value): + self.value = value + + def __length_hint__(self): + if type(self.value) is type: + raise self.value + else: + return self.value + + self.assertEqual(operator.length_hint([], 2), 0) + self.assertEqual(operator.length_hint(iter([1, 2, 3])), 3) + + self.assertEqual(operator.length_hint(X(2)), 2) + self.assertEqual(operator.length_hint(X(NotImplemented), 4), 4) + self.assertEqual(operator.length_hint(X(TypeError), 12), 12) + with self.assertRaises(TypeError): + operator.length_hint(X("abc")) + with self.assertRaises(ValueError): + operator.length_hint(X(-2)) + with self.assertRaises(LookupError): + operator.length_hint(X(LookupError)) + + def test_dunder_is_original(self): + #operator = self.module + + names = [name for name in dir(operator) if not name.startswith('_')] + for name in names: + orig = getattr(operator, name) + dunder = getattr(operator, '__' + name.strip('_') + '__', None) + if dunder: + self.assertIs(dunder, orig) + + def test_complex_operator(self): + self.assertRaises(TypeError, operator.lt, 1j, 2j) + self.assertRaises(TypeError, operator.le, 1j, 2j) + self.assertRaises(TypeError, operator.ge, 1j, 2j) + self.assertRaises(TypeError, operator.gt, 1j, 2j) +def test_main(): + test_support.run_unittest(OperatorTestCase) + +if __name__ == "__main__": + test_main() diff --git 
a/third_party/pypy/LICENSE b/third_party/pypy/LICENSE new file mode 100644 index 00000000..06334b4d --- /dev/null +++ b/third_party/pypy/LICENSE @@ -0,0 +1,486 @@ +#encoding utf-8 + +License +======= + +Except when otherwise stated (look for LICENSE files in directories or +information at the beginning of each file) all software and documentation in +the 'rpython', 'pypy', 'ctype_configure', 'dotviewer', 'demo', 'lib_pypy', +'py', and '_pytest' directories is licensed as follows: + + The MIT License + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation + files (the "Software"), to deal in the Software without + restriction, including without limitation the rights to use, + copy, modify, merge, publish, distribute, sublicense, and/or + sell copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. + + +PyPy Copyright holders 2003-2017 +----------------------------------- + +Except when otherwise stated (look for LICENSE files or information at +the beginning of each file) the files in the 'pypy' directory are each +copyrighted by one or more of the following people and organizations: + + Armin Rigo + Maciej Fijalkowski + Carl Friedrich Bolz + Amaury Forgeot d'Arc + Antonio Cuni + Samuele Pedroni + Matti Picus + Ronan Lamy + Alex Gaynor + Philip Jenvey + Brian Kearns + Richard Plangger + Michael Hudson + Manuel Jacob + David Schneider + Holger Krekel + Christian Tismer + Hakan Ardo + Benjamin Peterson + Anders Chrigstrom + Wim Lavrijsen + Eric van Riet Paap + Richard Emslie + Alexander Schremmer + Remi Meier + Dan Villiom Podlaski Christiansen + Lukas Diekmann + Sven Hager + Anders Lehmann + Aurelien Campeas + Niklaus Haldimann + Camillo Bruni + Laura Creighton + Romain Guillebert + Toon Verwaest + Leonardo Santagada + Seo Sanghyeon + Ronny Pfannschmidt + Justin Peel + Raffael Tfirst + David Edelsohn + Anders Hammarquist + Jakub Gustak + Gregor Wegberg + Guido Wesdorp + Lawrence Oluyede + Bartosz Skowron + Daniel Roberts + Adrien Di Mascio + Niko Matsakis + Alexander Hesse + Ludovic Aubry + Jacob Hallen + Jason Creighton + Mark Young + Alex Martelli + Spenser Bauman + Michal Bendowski + stian + Jan de Mooij + Tyler Wade + Vincent Legoll + Michael Foord + Stephan Diehl + Stefan Schwarzer + Tomek Meka + Valentino Volonghi + Stefano Rivera + Patrick Maupin + Devin Jeanpierre + Bob Ippolito + Bruno Gola + David Malcolm + Jean-Paul Calderone + Edd Barrett + Squeaky + Timo Paulssen + Marius Gedminas + Alexandre Fayolle + Simon Burton + Nicolas Truessel + Martin Matusiak + Wenzhu Man + Konstantin Lopuhin + John Witulski + Laurence Tratt + Greg Price + Ivan Sichmann Freitas + Dario Bertini + Jeremy Thurgood + Mark Pearse + Simon Cross + Tobias Pape + Andreas Stührk + Jean-Philippe St. 
Pierre + Guido van Rossum + Pavel Vinogradov + Paweł Piotr Przeradowski + William Leslie + marky1991 + Ilya Osadchiy + Tobias Oberstein + Paul deGrandis + Boris Feigin + Taavi Burns + Adrian Kuhn + tav + Georg Brandl + Bert Freudenberg + Stian Andreassen + Wanja Saatkamp + Mike Blume + Gerald Klix + Oscar Nierstrasz + Rami Chowdhury + Stefan H. Muller + Joannah Nanjekye + Eugene Oden + Tim Felgentreff + Jeff Terrace + Henry Mason + Vasily Kuznetsov + Preston Timmons + David Ripton + Dusty Phillips + Lukas Renggli + Guenter Jantzen + Ned Batchelder + Amit Regmi + Anton Gulenko + Sergey Matyunin + Jasper Schulz + Andrew Chambers + Nicolas Chauvat + Andrew Durdin + Ben Young + Michael Schneider + Nicholas Riley + Jason Chu + Igor Trindade Oliveira + Yichao Yu + Michael Twomey + Rocco Moretti + Gintautas Miliauskas + Lucian Branescu Mihaila + anatoly techtonik + Karl Bartel + Gabriel Lavoie + Jared Grubb + Olivier Dormond + Wouter van Heyst + Sebastian Pawluś + Brian Dorsey + Victor Stinner + Andrews Medina + Aaron Iles + Toby Watson + Daniel Patrick + Stuart Williams + Antoine Pitrou + Christian Hudon + Justas Sadzevicius + Neil Shepperd + Michael Cheng + Mikael Schönenberg + Stanislaw Halik + Berkin Ilbeyi + Gasper Zejn + Faye Zhao + Elmo Mäntynen + Anders Qvist + Corbin Simpson + Chirag Jadwani + Jonathan David Riehl + Beatrice During + Alex Perry + p_zieschang@yahoo.de + Robert Zaremba + Alan McIntyre + Alexander Sedov + Vaibhav Sood + Reuben Cummings + Attila Gobi + Christopher Pope + Tristan Arthur + Christian Tismer + Dan Stromberg + Carl Meyer + Florin Papa + Valentina Mukhamedzhanova + Stefano Parmesan + touilleMan + Marc Abramowitz + Arjun Naik + Aaron Gallagher + Alexis Daboville + Pieter Zieschang + Karl Ramm + Lukas Vacek + Omer Katz + Jacek Generowicz + Sylvain Thenault + Jakub Stasiak + Stefan Beyer + Andrew Dalke + Alejandro J. Cura + Vladimir Kryachko + Gabriel + Mark Williams + Kunal Grover + Nathan Taylor + Travis Francis Athougies + Yasir Suhail + Sergey Kishchenko + Martin Blais + Lutz Paelike + Ian Foote + Philipp Rustemeuer + Catalin Gabriel Manciu + Jacob Oscarson + Ryan Gonzalez + Kristjan Valur Jonsson + Lucio Torre + Richard Lancaster + Dan Buch + Lene Wagner + Tomo Cocoa + Alecsandru Patrascu + David Lievens + Neil Blakey-Milner + Henrik Vendelbo + Lars Wassermann + Ignas Mikalajunas + Christoph Gerum + Miguel de Val Borro + Artur Lisiecki + Toni Mattis + Laurens Van Houtven + Bobby Impollonia + Roberto De Ioris + Jeong YunWon + Christopher Armstrong + Aaron Tubbs + Vasantha Ganesh K + Jason Michalski + Markus Holtermann + Andrew Thompson + Yusei Tahara + Ruochen Huang + Fabio Niephaus + Akira Li + Gustavo Niemeyer + Rafał Gałczyński + Logan Chien + Lucas Stadler + roberto@goyle + Matt Bogosian + Yury V. 
Zaytsev + florinpapa + Anders Sigfridsson + Nikolay Zinov + rafalgalczynski@gmail.com + Joshua Gilbert + Anna Katrina Dominguez + Kim Jin Su + Amber Brown + Ben Darnell + Juan Francisco Cantero Hurtado + Godefroid Chappelle + Julian Berman + Michael Hudson-Doyle + Floris Bruynooghe + Stephan Busemann + Dan Colish + timo + Volodymyr Vladymyrov + Daniel Neuhäuser + Flavio Percoco + halgari + Jim Baker + Chris Lambacher + coolbutuseless@gmail.com + Mike Bayer + Rodrigo Araújo + Daniil Yarancev + OlivierBlanvillain + Jonas Pfannschmidt + Zearin + Andrey Churin + Dan Crosta + reubano@gmail.com + Julien Phalip + Roman Podoliaka + Eli Stevens + Boglarka Vezer + PavloKapyshin + Tomer Chachamu + Christopher Groskopf + Asmo Soinio + Antony Lee + Jim Hunziker + shoma hosaka + Buck Golemon + JohnDoe + yrttyr + Michael Chermside + Anna Ravencroft + remarkablerocket + Berker Peksag + Christian Muirhead + soareschen + Matthew Miller + Konrad Delong + Dinu Gherman + pizi + James Robert + Armin Ronacher + Diana Popa + Mads Kiilerich + Brett Cannon + aliceinwire + Zooko Wilcox-O Hearn + James Lan + jiaaro + Markus Unterwaditzer + Kristoffer Kleine + Graham Markall + Dan Loewenherz + werat + Niclas Olofsson + Chris Pressey + Tobias Diaz + Nikolaos-Digenis Karagiannis + Kurt Griffiths + Ben Mather + Donald Stufft + Dan Sanders + Jason Madden + Yaroslav Fedevych + Even Wiik Thomassen + Stefan Marr + + Heinrich-Heine University, Germany + Open End AB (formerly AB Strakt), Sweden + merlinux GmbH, Germany + tismerysoft GmbH, Germany + Logilab Paris, France + DFKI GmbH, Germany + Impara, Germany + Change Maker, Sweden + University of California Berkeley, USA + Google Inc. + King's College London + +The PyPy Logo as used by http://speed.pypy.org and others was created +by Samuel Reis and is distributed on terms of Creative Commons Share Alike +License. + +License for 'lib-python/2.7' +============================ + +Except when otherwise stated (look for LICENSE files or copyright/license +information at the beginning of each file) the files in the 'lib-python/2.7' +directory are all copyrighted by the Python Software Foundation and licensed +under the terms that you can find here: https://docs.python.org/2/license.html + +License for 'pypy/module/unicodedata/' +====================================== + +The following files are from the website of The Unicode Consortium +at http://www.unicode.org/. For the terms of use of these files, see +http://www.unicode.org/terms_of_use.html . Or they are derived from +files from the above website, and the same terms of use apply. + + CompositionExclusions-*.txt + EastAsianWidth-*.txt + LineBreak-*.txt + UnicodeData-*.txt + UnihanNumeric-*.txt + +License for 'dotviewer/font/' +============================= + +Copyright (C) 2008 The Android Open Source Project + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Detailed license information is contained in the NOTICE file in the +directory. 
+ + +Licenses and Acknowledgements for Incorporated Software +======================================================= + +This section is an incomplete, but growing list of licenses and +acknowledgements for third-party software incorporated in the PyPy +distribution. + +License for 'Tcl/Tk' +-------------------- + +This copy of PyPy contains library code that may, when used, result in +the Tcl/Tk library to be loaded. PyPy also includes code that may be +regarded as being a copy of some parts of the Tcl/Tk header files. +You may see a copy of the License for Tcl/Tk in the file +`lib_pypy/_tkinter/license.terms` included here. + +License for 'bzip2' +------------------- + +This copy of PyPy may be linked (dynamically or statically) with the +bzip2 library. You may see a copy of the License for bzip2/libbzip2 at + + http://www.bzip.org/1.0.5/bzip2-manual-1.0.5.html + +License for 'openssl' +--------------------- + +This copy of PyPy may be linked (dynamically or statically) with the +openssl library. You may see a copy of the License for OpenSSL at + + https://www.openssl.org/source/license.html + +License for 'gdbm' +------------------ + +The gdbm module includes code from gdbm.h, which is distributed under +the terms of the GPL license version 2 or any later version. Thus the +gdbm module, provided in the file lib_pypy/gdbm.py, is redistributed +under the terms of the GPL license as well. + +License for 'rpython/rlib/rvmprof/src' +-------------------------------------- + +The code is based on gperftools. You may see a copy of the License for it at + + https://github.com/gperftools/gperftools/blob/master/COPYING diff --git a/third_party/pypy/README.md b/third_party/pypy/README.md new file mode 100644 index 00000000..f4cc1621 --- /dev/null +++ b/third_party/pypy/README.md @@ -0,0 +1,3 @@ +Canonical versions of the files in this folder come from the +[lib-python/2.7](https://bitbucket.org/pypy/pypy/src/23fd2966aada422b331d7d752fc383178deffb27/lib-python/2.7/?at=default) +directory of the [PyPy repo](https://bitbucket.org/pypy/pypy). diff --git a/third_party/pypy/_sha512.py b/third_party/pypy/_sha512.py index 9bde3b65..eec167e2 100644 --- a/third_party/pypy/_sha512.py +++ b/third_party/pypy/_sha512.py @@ -277,17 +277,17 @@ def copy(self): return new def test(): - import _sha512 +# import _sha512 a_str = "just a test string" - assert _sha512.sha512().hexdigest() == sha512().hexdigest() - assert _sha512.sha512(a_str).hexdigest() == sha512(a_str).hexdigest() - assert _sha512.sha512(a_str*7).hexdigest() == sha512(a_str*7).hexdigest() + assert sha512().hexdigest() == sha512().hexdigest() + assert sha512(a_str).hexdigest() == sha512(a_str).hexdigest() + assert sha512(a_str*7).hexdigest() == sha512(a_str*7).hexdigest() s = sha512(a_str) s.update(a_str) - assert _sha512.sha512(a_str+a_str).hexdigest() == s.hexdigest() + assert sha512(a_str+a_str).hexdigest() == s.hexdigest() if __name__ == "__main__": test() diff --git a/third_party/pypy/datetime.py b/third_party/pypy/datetime.py new file mode 100644 index 00000000..fe99e1e3 --- /dev/null +++ b/third_party/pypy/datetime.py @@ -0,0 +1,2102 @@ +"""Concrete date/time and related types -- prototype implemented in Python. 
+ +See http://www.zope.org/Members/fdrake/DateTimeWiki/FrontPage + +See also http://dir.yahoo.com/Reference/calendars/ + +For a primer on DST, including many current DST rules, see +http://webexhibits.org/daylightsaving/ + +For more about DST than you ever wanted to know, see +ftp://elsie.nci.nih.gov/pub/ + +Sources for time zone and DST data: http://www.twinsun.com/tz/tz-link.htm + +This was originally copied from the sandbox of the CPython CVS repository. +Thanks to Tim Peters for suggesting using it. +""" + +# from __future__ import division +import time as _time +import math as _math +# import struct as _struct +import _struct + +def divmod(x, y): + x, y = int(x), int(y) + return x / y, x % y + +_SENTINEL = object() + +def _cmp(x, y): + return 0 if x == y else 1 if x > y else -1 + +def _round(x): + return int(_math.floor(x + 0.5) if x >= 0.0 else _math.ceil(x - 0.5)) + +MINYEAR = 1 +MAXYEAR = 9999 +_MINYEARFMT = 1900 + +_MAX_DELTA_DAYS = 999999999 + +# Utility functions, adapted from Python's Demo/classes/Dates.py, which +# also assumes the current Gregorian calendar indefinitely extended in +# both directions. Difference: Dates.py calls January 1 of year 0 day +# number 1. The code here calls January 1 of year 1 day number 1. This is +# to match the definition of the "proleptic Gregorian" calendar in Dershowitz +# and Reingold's "Calendrical Calculations", where it's the base calendar +# for all computations. See the book for algorithms for converting between +# proleptic Gregorian ordinals and many other calendar systems. + +_DAYS_IN_MONTH = [-1, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + +_DAYS_BEFORE_MONTH = [-1] +dbm = 0 +for dim in _DAYS_IN_MONTH[1:]: + _DAYS_BEFORE_MONTH.append(dbm) + dbm += dim +del dbm, dim + +def _is_leap(year): + "year -> 1 if leap year, else 0." + return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0) + +def _days_before_year(year): + "year -> number of days before January 1st of year." + y = year - 1 + return y*365 + y//4 - y//100 + y//400 + +def _days_in_month(year, month): + "year, month -> number of days in that month in that year." + assert 1 <= month <= 12, month + if month == 2 and _is_leap(year): + return 29 + return _DAYS_IN_MONTH[month] + +def _days_before_month(year, month): + "year, month -> number of days in year preceding first day of month." + assert 1 <= month <= 12, 'month must be in 1..12' + return _DAYS_BEFORE_MONTH[month] + (month > 2 and _is_leap(year)) + +def _ymd2ord(year, month, day): + "year, month, day -> ordinal, considering 01-Jan-0001 as day 1." + assert 1 <= month <= 12, 'month must be in 1..12' + dim = _days_in_month(year, month) + assert 1 <= day <= dim, ('day must be in 1..%d' % dim) + return (_days_before_year(year) + + _days_before_month(year, month) + + day) + +_DI400Y = _days_before_year(401) # number of days in 400 years +_DI100Y = _days_before_year(101) # " " " " 100 " +_DI4Y = _days_before_year(5) # " " " " 4 " + +# A 4-year cycle has an extra leap day over what we'd get from pasting +# together 4 single years. +assert _DI4Y == 4 * 365 + 1 + +# Similarly, a 400-year cycle has an extra leap day over what we'd get from +# pasting together 4 100-year cycles. +assert _DI400Y == 4 * _DI100Y + 1 + +# OTOH, a 100-year cycle has one fewer leap day than we'd get from +# pasting together 25 4-year cycles. 
+assert _DI100Y == 25 * _DI4Y - 1 + +_US_PER_US = 1 +_US_PER_MS = 1000 +_US_PER_SECOND = 1000000 +_US_PER_MINUTE = 60000000 +_SECONDS_PER_DAY = 24 * 3600 +_US_PER_HOUR = 3600000000 +_US_PER_DAY = 86400000000 +_US_PER_WEEK = 604800000000 + +def _ord2ymd(n): + "ordinal -> (year, month, day), considering 01-Jan-0001 as day 1." + + # n is a 1-based index, starting at 1-Jan-1. The pattern of leap years + # repeats exactly every 400 years. The basic strategy is to find the + # closest 400-year boundary at or before n, then work with the offset + # from that boundary to n. Life is much clearer if we subtract 1 from + # n first -- then the values of n at 400-year boundaries are exactly + # those divisible by _DI400Y: + # + # D M Y n n-1 + # -- --- ---- ---------- ---------------- + # 31 Dec -400 -_DI400Y -_DI400Y -1 + # 1 Jan -399 -_DI400Y +1 -_DI400Y 400-year boundary + # ... + # 30 Dec 000 -1 -2 + # 31 Dec 000 0 -1 + # 1 Jan 001 1 0 400-year boundary + # 2 Jan 001 2 1 + # 3 Jan 001 3 2 + # ... + # 31 Dec 400 _DI400Y _DI400Y -1 + # 1 Jan 401 _DI400Y +1 _DI400Y 400-year boundary + n -= 1 + n400, n = divmod(n, _DI400Y) + year = n400 * 400 + 1 # ..., -399, 1, 401, ... + + # Now n is the (non-negative) offset, in days, from January 1 of year, to + # the desired date. Now compute how many 100-year cycles precede n. + # Note that it's possible for n100 to equal 4! In that case 4 full + # 100-year cycles precede the desired day, which implies the desired + # day is December 31 at the end of a 400-year cycle. + n100, n = divmod(n, _DI100Y) + + # Now compute how many 4-year cycles precede it. + n4, n = divmod(n, _DI4Y) + + # And now how many single years. Again n1 can be 4, and again meaning + # that the desired day is December 31 at the end of the 4-year cycle. + n1, n = divmod(n, 365) + + year += n100 * 100 + n4 * 4 + n1 + if n1 == 4 or n100 == 4: + assert n == 0 + return year-1, 12, 31 + + # Now the year is correct, and n is the offset from January 1. We find + # the month via an estimate that's either exact or one too large. + leapyear = n1 == 3 and (n4 != 24 or n100 == 3) + assert leapyear == _is_leap(year) + month = (n + 50) >> 5 + preceding = _DAYS_BEFORE_MONTH[month] + (month > 2 and leapyear) + if preceding > n: # estimate is too large + month -= 1 + preceding -= _DAYS_IN_MONTH[month] + (month == 2 and leapyear) + n -= preceding + assert 0 <= n < _days_in_month(year, month) + + # Now the year and month are correct, and n is the offset from the + # start of that month: we're done! + return year, month, n+1 + +# Month and day names. For localized versions, see the calendar module. +_MONTHNAMES = [None, "Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"] +_DAYNAMES = [None, "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] + + +def _build_struct_time(y, m, d, hh, mm, ss, dstflag): + wday = (_ymd2ord(y, m, d) + 6) % 7 + dnum = _days_before_month(y, m) + d + return _time.struct_time((y, m, d, hh, mm, ss, wday, dnum, dstflag)) + +def _format_time(hh, mm, ss, us): + # Skip trailing microseconds when us==0. + result = "%02d:%02d:%02d" % (hh, mm, ss) + if us: + result += ".%06d" % us + return result + +# Correctly substitute for %z and %Z escapes in strftime formats. +# def _wrap_strftime(object, format, timetuple): +# year = timetuple[0] +# if year < _MINYEARFMT: +# raise ValueError("year=%d is before %d; the datetime strftime() " +# "methods require year >= %d" % +# (year, _MINYEARFMT, _MINYEARFMT)) +# # Don't call utcoffset() or tzname() unless actually needed. 
+# freplace = None # the string to use for %f +# zreplace = None # the string to use for %z +# Zreplace = None # the string to use for %Z + +# # Scan format for %z and %Z escapes, replacing as needed. +# newformat = [] +# push = newformat.append +# i, n = 0, len(format) +# while i < n: +# ch = format[i] +# i += 1 +# if ch == '%': +# if i < n: +# ch = format[i] +# i += 1 +# if ch == 'f': +# if freplace is None: +# freplace = '%06d' % getattr(object, +# 'microsecond', 0) +# newformat.append(freplace) +# elif ch == 'z': +# if zreplace is None: +# zreplace = "" +# if hasattr(object, "_utcoffset"): +# offset = object._utcoffset() +# if offset is not None: +# sign = '+' +# if offset < 0: +# offset = -offset +# sign = '-' +# h, m = divmod(offset, 60) +# zreplace = '%c%02d%02d' % (sign, h, m) +# assert '%' not in zreplace +# newformat.append(zreplace) +# elif ch == 'Z': +# if Zreplace is None: +# Zreplace = "" +# if hasattr(object, "tzname"): +# s = object.tzname() +# if s is not None: +# # strftime is going to have at this: escape % +# Zreplace = s.replace('%', '%%') +# newformat.append(Zreplace) +# else: +# push('%') +# push(ch) +# else: +# push('%') +# else: +# push(ch) +# newformat = "".join(newformat) +# return _time.strftime(newformat, timetuple) + +# Just raise TypeError if the arg isn't None or a string. +def _check_tzname(name): + if name is not None and not isinstance(name, str): + raise TypeError("tzinfo.tzname() must return None or string, " + "not '%s'" % type(name)) + +# name is the offset-producing method, "utcoffset" or "dst". +# offset is what it returned. +# If offset isn't None or timedelta, raises TypeError. +# If offset is None, returns None. +# Else offset is checked for being in range, and a whole # of minutes. +# If it is, its integer value is returned. Else ValueError is raised. 
+def _check_utc_offset(name, offset): + assert name in ("utcoffset", "dst") + if offset is None: + return + if not isinstance(offset, timedelta): + raise TypeError("tzinfo.%s() must return None " + "or timedelta, not '%s'" % (name, type(offset))) + days = offset.days + if days < -1 or days > 0: + offset = 1440 # trigger out-of-range + else: + seconds = days * 86400 + offset.seconds + minutes, seconds = divmod(seconds, 60) + if seconds or offset.microseconds: + raise ValueError("tzinfo.%s() must return a whole number " + "of minutes" % name) + offset = minutes + if not -1440 < offset < 1440: + raise ValueError("%s()=%d, must be in -1439..1439" % (name, offset)) + return offset + +def _check_int_field(value): + if isinstance(value, int): + return int(value) + if not isinstance(value, float): + try: + value = value.__int__() + except AttributeError: + pass + else: + if isinstance(value, int): + return int(value) + elif isinstance(value, long): + return int(long(value)) + raise TypeError('__int__ method should return an integer') + raise TypeError('an integer is required') + raise TypeError('integer argument expected, got float') + +def _check_date_fields(year, month, day): + year = _check_int_field(year) + month = _check_int_field(month) + day = _check_int_field(day) + if not MINYEAR <= year <= MAXYEAR: + raise ValueError('year must be in %d..%d' % (MINYEAR, MAXYEAR), year) + if not 1 <= month <= 12: + raise ValueError('month must be in 1..12', month) + dim = _days_in_month(year, month) + if not 1 <= day <= dim: + raise ValueError('day must be in 1..%d' % dim, day) + return year, month, day + +def _check_time_fields(hour, minute, second, microsecond): + hour = _check_int_field(hour) + minute = _check_int_field(minute) + second = _check_int_field(second) + microsecond = _check_int_field(microsecond) + if not 0 <= hour <= 23: + raise ValueError('hour must be in 0..23', hour) + if not 0 <= minute <= 59: + raise ValueError('minute must be in 0..59', minute) + if not 0 <= second <= 59: + raise ValueError('second must be in 0..59', second) + if not 0 <= microsecond <= 999999: + raise ValueError('microsecond must be in 0..999999', microsecond) + return hour, minute, second, microsecond + +def _check_tzinfo_arg(tz): + if tz is not None and not isinstance(tz, tzinfo): + raise TypeError("tzinfo argument must be None or of a tzinfo subclass") + + +# Notes on comparison: In general, datetime module comparison operators raise +# TypeError when they don't know how to do a comparison themself. If they +# returned NotImplemented instead, comparison could (silently) fall back to +# the default compare-objects-by-comparing-their-memory-addresses strategy, +# and that's not helpful. There are two exceptions: +# +# 1. For date and datetime, if the other object has a "timetuple" attr, +# NotImplemented is returned. This is a hook to allow other kinds of +# datetime-like objects a chance to intercept the comparison. +# +# 2. Else __eq__ and __ne__ return False and True, respectively. This is +# so opertaions like +# +# x == y +# x != y +# x in sequence +# x not in sequence +# dict[x] = y +# +# don't raise annoying TypeErrors just because a datetime object +# is part of a heterogeneous collection. If there's no known way to +# compare X to a datetime, saying they're not equal is reasonable. 
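# ---------------------------------------------------------------------------
# [Editor's illustrative sketch -- not part of the vendored PyPy file in the
# patch above or below.]  A minimal demonstration of the comparison rules
# described in the "Notes on comparison" block above, assuming the module
# above is importable as `datetime`; the `_Timetupled` helper class is
# hypothetical and exists only for this example.
#
#     from datetime import date
#
#     d = date(2010, 1, 1)
#
#     # __eq__/__ne__ fall back to False/True for unrelated types, so
#     # membership tests and dict lookups do not raise:
#     assert (d == "2010-01-01") is False
#     assert d not in ["2010-01-01"]
#
#     # Ordering against an unrelated type raises TypeError via _cmperror:
#     try:
#         d < "2010-01-01"
#     except TypeError:
#         pass
#
#     # An object exposing a `timetuple` attribute gets NotImplemented back,
#     # so its own reflected __eq__ gets a chance to decide the result:
#     class _Timetupled(object):
#         timetuple = None
#         def __eq__(self, other):
#             return True
#     assert (d == _Timetupled()) is True
# ---------------------------------------------------------------------------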
+ +def _cmperror(x, y): + raise TypeError("can't compare '%s' to '%s'" % ( + type(x).__name__, type(y).__name__)) + +def _normalize_pair(hi, lo, factor): + if not 0 <= lo <= factor-1: + inc, lo = divmod(lo, factor) + hi += inc + return hi, lo + +def _normalize_datetime(y, m, d, hh, mm, ss, us, ignore_overflow=False): + # Normalize all the inputs, and store the normalized values. + ss, us = _normalize_pair(ss, us, 1000000) + mm, ss = _normalize_pair(mm, ss, 60) + hh, mm = _normalize_pair(hh, mm, 60) + d, hh = _normalize_pair(d, hh, 24) + y, m, d = _normalize_date(y, m, d, ignore_overflow) + return y, m, d, hh, mm, ss, us + +def _normalize_date(year, month, day, ignore_overflow=False): + # That was easy. Now it gets muddy: the proper range for day + # can't be determined without knowing the correct month and year, + # but if day is, e.g., plus or minus a million, the current month + # and year values make no sense (and may also be out of bounds + # themselves). + # Saying 12 months == 1 year should be non-controversial. + if not 1 <= month <= 12: + year, month = _normalize_pair(year, month-1, 12) + month += 1 + assert 1 <= month <= 12 + + # Now only day can be out of bounds (year may also be out of bounds + # for a datetime object, but we don't care about that here). + # If day is out of bounds, what to do is arguable, but at least the + # method here is principled and explainable. + dim = _days_in_month(year, month) + if not 1 <= day <= dim: + # Move day-1 days from the first of the month. First try to + # get off cheap if we're only one day out of range (adjustments + # for timezone alone can't be worse than that). + if day == 0: # move back a day + month -= 1 + if month > 0: + day = _days_in_month(year, month) + else: + year, month, day = year-1, 12, 31 + elif day == dim + 1: # move forward a day + month += 1 + day = 1 + if month > 12: + month = 1 + year += 1 + else: + ordinal = _ymd2ord(year, month, 1) + (day - 1) + year, month, day = _ord2ymd(ordinal) + + if not ignore_overflow and not MINYEAR <= year <= MAXYEAR: + raise OverflowError("date value out of range") + return year, month, day + +def _accum(tag, sofar, num, factor, leftover): + if isinstance(num, (int, long)): + prod = num * factor + rsum = sofar + prod + return rsum, leftover + if isinstance(num, float): + fracpart, intpart = _math.modf(num) + prod = int(intpart) * factor + rsum = sofar + prod + if fracpart == 0.0: + return rsum, leftover + assert isinstance(factor, (int, long)) + fracpart, intpart = _math.modf(factor * fracpart) + rsum += int(intpart) + return rsum, leftover + fracpart + raise TypeError("unsupported type for timedelta %s component: %s" % + (tag, type(num))) + +class timedelta(object): + """Represent the difference between two datetime objects. + + Supported operators: + + - add, subtract timedelta + - unary plus, minus, abs + - compare to timedelta + - multiply, divide by int/long + + In addition, datetime supports subtraction of two datetime objects + returning a timedelta, and addition or subtraction of a datetime + and a timedelta giving a datetime. + + Representation: (days, seconds, microseconds). Why? Because I + felt like it. 
+ """ + __slots__ = '_days', '_seconds', '_microseconds', '_hashcode' + + def __new__(cls, days=_SENTINEL, seconds=_SENTINEL, microseconds=_SENTINEL, + milliseconds=_SENTINEL, minutes=_SENTINEL, hours=_SENTINEL, weeks=_SENTINEL): + x = 0 + leftover = 0.0 + if microseconds is not _SENTINEL: + x, leftover = _accum("microseconds", x, microseconds, _US_PER_US, leftover) + if milliseconds is not _SENTINEL: + x, leftover = _accum("milliseconds", x, milliseconds, _US_PER_MS, leftover) + if seconds is not _SENTINEL: + x, leftover = _accum("seconds", x, seconds, _US_PER_SECOND, leftover) + if minutes is not _SENTINEL: + x, leftover = _accum("minutes", x, minutes, _US_PER_MINUTE, leftover) + if hours is not _SENTINEL: + x, leftover = _accum("hours", x, hours, _US_PER_HOUR, leftover) + if days is not _SENTINEL: + x, leftover = _accum("days", x, days, _US_PER_DAY, leftover) + if weeks is not _SENTINEL: + x, leftover = _accum("weeks", x, weeks, _US_PER_WEEK, leftover) + if leftover != 0.0: + x += _round(leftover) + return cls._from_microseconds(x) + + @classmethod + def _from_microseconds(cls, us): + s, us = divmod(us, _US_PER_SECOND) + d, s = divmod(s, _SECONDS_PER_DAY) + return cls._create(d, s, us, False) + + @classmethod + def _create(cls, d, s, us, normalize): + if normalize: + s, us = _normalize_pair(s, us, 1000000) + d, s = _normalize_pair(d, s, 24*3600) + + if not -_MAX_DELTA_DAYS <= d <= _MAX_DELTA_DAYS: + raise OverflowError("days=%d; must have magnitude <= %d" % (d, _MAX_DELTA_DAYS)) + + self = object.__new__(cls) + self._days = d + self._seconds = s + self._microseconds = us + self._hashcode = -1 + return self + + def _to_microseconds(self): + return ((self._days * _SECONDS_PER_DAY + self._seconds) * _US_PER_SECOND + + self._microseconds) + + def __repr__(self): + module = "datetime." 
if self.__class__ is timedelta else "" + if self._microseconds: + return "%s(%d, %d, %d)" % (module + self.__class__.__name__, + self._days, + self._seconds, + self._microseconds) + if self._seconds: + return "%s(%d, %d)" % (module + self.__class__.__name__, + self._days, + self._seconds) + return "%s(%d)" % (module + self.__class__.__name__, self._days) + + def __str__(self): + mm, ss = divmod(self._seconds, 60) + hh, mm = divmod(mm, 60) + s = "%d:%02d:%02d" % (hh, mm, ss) + if self._days: + def plural(n): + return n, abs(n) != 1 and "s" or "" + s = ("%d day%s, " % plural(self._days)) + s + if self._microseconds: + s = s + ".%06d" % self._microseconds + return s + + def total_seconds(self): + """Total seconds in the duration.""" + # return self._to_microseconds() / 10**6 + return float(self._to_microseconds()) / float(10**6) + + # Read-only field accessors + @property + def days(self): + """days""" + return self._days + + @property + def seconds(self): + """seconds""" + return self._seconds + + @property + def microseconds(self): + """microseconds""" + return self._microseconds + + def __add__(self, other): + if isinstance(other, timedelta): + # for CPython compatibility, we cannot use + # our __class__ here, but need a real timedelta + return timedelta._create(self._days + other._days, + self._seconds + other._seconds, + self._microseconds + other._microseconds, + True) + return NotImplemented + + def __sub__(self, other): + if isinstance(other, timedelta): + # for CPython compatibility, we cannot use + # our __class__ here, but need a real timedelta + return timedelta._create(self._days - other._days, + self._seconds - other._seconds, + self._microseconds - other._microseconds, + True) + return NotImplemented + + def __neg__(self): + # for CPython compatibility, we cannot use + # our __class__ here, but need a real timedelta + return timedelta._create(-self._days, + -self._seconds, + -self._microseconds, + True) + + def __pos__(self): + # for CPython compatibility, we cannot use + # our __class__ here, but need a real timedelta + return timedelta._create(self._days, + self._seconds, + self._microseconds, + False) + + def __abs__(self): + if self._days < 0: + return -self + else: + return self + + def __mul__(self, other): + if not isinstance(other, (int, long)): + return NotImplemented + usec = self._to_microseconds() + return timedelta._from_microseconds(usec * other) + + __rmul__ = __mul__ + + def __div__(self, other): + if not isinstance(other, (int, long)): + return NotImplemented + usec = self._to_microseconds() + # return timedelta._from_microseconds(usec // other) + return timedelta._from_microseconds(int(usec) / int(other)) + + __floordiv__ = __div__ + + # Comparisons of timedelta objects with other. 
+ + def __eq__(self, other): + if isinstance(other, timedelta): + return self._cmp(other) == 0 + else: + return False + + def __ne__(self, other): + if isinstance(other, timedelta): + return self._cmp(other) != 0 + else: + return True + + def __le__(self, other): + if isinstance(other, timedelta): + return self._cmp(other) <= 0 + else: + _cmperror(self, other) + + def __lt__(self, other): + if isinstance(other, timedelta): + return self._cmp(other) < 0 + else: + _cmperror(self, other) + + def __ge__(self, other): + if isinstance(other, timedelta): + return self._cmp(other) >= 0 + else: + _cmperror(self, other) + + def __gt__(self, other): + if isinstance(other, timedelta): + return self._cmp(other) > 0 + else: + _cmperror(self, other) + + def _cmp(self, other): + assert isinstance(other, timedelta) + return _cmp(self._getstate(), other._getstate()) + + def __hash__(self): + if self._hashcode == -1: + self._hashcode = hash(self._getstate()) + return self._hashcode + + def __nonzero__(self): + return (self._days != 0 or + self._seconds != 0 or + self._microseconds != 0) + + # Pickle support. + + def _getstate(self): + return (self._days, self._seconds, self._microseconds) + + def __reduce__(self): + return (self.__class__, self._getstate()) + +timedelta.min = timedelta(-_MAX_DELTA_DAYS) +timedelta.max = timedelta(_MAX_DELTA_DAYS, 24*3600-1, 1000000-1) +timedelta.resolution = timedelta(microseconds=1) + +class date(object): + """Concrete date type. + + Constructors: + + __new__() + fromtimestamp() + today() + fromordinal() + + Operators: + + __repr__, __str__ + __cmp__, __hash__ + __add__, __radd__, __sub__ (add/radd only with timedelta arg) + + Methods: + + timetuple() + toordinal() + weekday() + isoweekday(), isocalendar(), isoformat() + ctime() + strftime() + + Properties (readonly): + year, month, day + """ + __slots__ = '_year', '_month', '_day', '_hashcode' + + def __new__(cls, year, month=None, day=None): + """Constructor. + + Arguments: + + year, month, day (required, base 1) + """ + # if month is None and isinstance(year, bytes) and len(year) == 4 and \ + # 1 <= ord(year[2]) <= 12: + # # Pickle support + # self = object.__new__(cls) + # self.__setstate(year) + # self._hashcode = -1 + # return self + year, month, day = _check_date_fields(year, month, day) + self = object.__new__(cls) + self._year = year + self._month = month + self._day = day + self._hashcode = -1 + return self + + # Additional constructors + + @classmethod + def fromtimestamp(cls, t): + "Construct a date from a POSIX timestamp (like time.time())." + y, m, d, hh, mm, ss, weekday, jday, dst = _time.localtime(t) + return cls(y, m, d) + + @classmethod + def today(cls): + "Construct a date from time.time()." + t = _time.time() + return cls.fromtimestamp(t) + + @classmethod + def fromordinal(cls, n): + """Contruct a date from a proleptic Gregorian ordinal. + + January 1 of year 1 is day 1. Only the year, month and day are + non-zero in the result. + """ + y, m, d = _ord2ymd(n) + return cls(y, m, d) + + # Conversions to string + + def __repr__(self): + """Convert to formal string, for repr(). + + >>> dt = datetime(2010, 1, 1) + >>> repr(dt) + 'datetime.datetime(2010, 1, 1, 0, 0)' + + >>> dt = datetime(2010, 1, 1, tzinfo=timezone.utc) + >>> repr(dt) + 'datetime.datetime(2010, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)' + """ + module = "datetime." 
if self.__class__ is date else "" + return "%s(%d, %d, %d)" % (module + self.__class__.__name__, + self._year, + self._month, + self._day) + + # XXX These shouldn't depend on time.localtime(), because that + # clips the usable dates to [1970 .. 2038). At least ctime() is + # easily done without using strftime() -- that's better too because + # strftime("%c", ...) is locale specific. + + def ctime(self): + "Return ctime() style string." + weekday = self.toordinal() % 7 or 7 + return "%s %s %2d 00:00:00 %04d" % ( + _DAYNAMES[weekday], + _MONTHNAMES[self._month], + self._day, self._year) + + # def strftime(self, format): + # "Format using strftime()." + # return _wrap_strftime(self, format, self.timetuple()) + + def __format__(self, fmt): + if not isinstance(fmt, (str, unicode)): + raise ValueError("__format__ expects str or unicode, not %s" % + fmt.__class__.__name__) + if len(fmt) != 0: + return self.strftime(fmt) + return str(self) + + def isoformat(self): + """Return the date formatted according to ISO. + + This is 'YYYY-MM-DD'. + + References: + - http://www.w3.org/TR/NOTE-datetime + - http://www.cl.cam.ac.uk/~mgk25/iso-time.html + """ + # return "%04d-%02d-%02d" % (self._year, self._month, self._day) + return "%s-%s-%s" % (str(self._year).zfill(4), str(self._month).zfill(2), str(self._day).zfill(2)) + + __str__ = isoformat + + # Read-only field accessors + @property + def year(self): + """year (1-9999)""" + return self._year + + @property + def month(self): + """month (1-12)""" + return self._month + + @property + def day(self): + """day (1-31)""" + return self._day + + # Standard conversions, __cmp__, __hash__ (and helpers) + + def timetuple(self): + "Return local time tuple compatible with time.localtime()." + return _build_struct_time(self._year, self._month, self._day, + 0, 0, 0, -1) + + def toordinal(self): + """Return proleptic Gregorian ordinal for the year, month and day. + + January 1 of year 1 is day 1. Only the year, month and day values + contribute to the result. + """ + return _ymd2ord(self._year, self._month, self._day) + + def replace(self, year=None, month=None, day=None): + """Return a new date with new values for the specified fields.""" + if year is None: + year = self._year + if month is None: + month = self._month + if day is None: + day = self._day + return date.__new__(type(self), year, month, day) + + # Comparisons of date objects with other. 
+ + def __eq__(self, other): + if isinstance(other, date): + return self._cmp(other) == 0 + elif hasattr(other, "timetuple"): + return NotImplemented + else: + return False + + def __ne__(self, other): + if isinstance(other, date): + return self._cmp(other) != 0 + elif hasattr(other, "timetuple"): + return NotImplemented + else: + return True + + def __le__(self, other): + if isinstance(other, date): + return self._cmp(other) <= 0 + elif hasattr(other, "timetuple"): + return NotImplemented + else: + _cmperror(self, other) + + def __lt__(self, other): + if isinstance(other, date): + return self._cmp(other) < 0 + elif hasattr(other, "timetuple"): + return NotImplemented + else: + _cmperror(self, other) + + def __ge__(self, other): + if isinstance(other, date): + return self._cmp(other) >= 0 + elif hasattr(other, "timetuple"): + return NotImplemented + else: + _cmperror(self, other) + + def __gt__(self, other): + if isinstance(other, date): + return self._cmp(other) > 0 + elif hasattr(other, "timetuple"): + return NotImplemented + else: + _cmperror(self, other) + + def _cmp(self, other): + assert isinstance(other, date) + y, m, d = self._year, self._month, self._day + y2, m2, d2 = other._year, other._month, other._day + return _cmp((y, m, d), (y2, m2, d2)) + + def __hash__(self): + "Hash." + if self._hashcode == -1: + self._hashcode = hash(self._getstate()) + return self._hashcode + + # Computations + + def _add_timedelta(self, other, factor): + y, m, d = _normalize_date( + self._year, + self._month, + self._day + other.days * factor) + return date(y, m, d) + + def __add__(self, other): + "Add a date to a timedelta." + if isinstance(other, timedelta): + return self._add_timedelta(other, 1) + return NotImplemented + + __radd__ = __add__ + + def __sub__(self, other): + """Subtract two dates, or a date and a timedelta.""" + if isinstance(other, date): + days1 = self.toordinal() + days2 = other.toordinal() + return timedelta._create(days1 - days2, 0, 0, False) + if isinstance(other, timedelta): + return self._add_timedelta(other, -1) + return NotImplemented + + def weekday(self): + "Return day of the week, where Monday == 0 ... Sunday == 6." + return (self.toordinal() + 6) % 7 + + # Day-of-the-week and week-of-the-year, according to ISO + + def isoweekday(self): + "Return day of the week, where Monday == 1 ... Sunday == 7." + # 1-Jan-0001 is a Monday + return self.toordinal() % 7 or 7 + + def isocalendar(self): + """Return a 3-tuple containing ISO year, week number, and weekday. + + The first ISO week of the year is the (Mon-Sun) week + containing the year's first Thursday; everything else derives + from that. + + The first week is 1; Monday is 1 ... Sunday is 7. + + ISO calendar algorithm taken from + http://www.phys.uu.nl/~vgent/calendar/isocalendar.htm + """ + year = self._year + week1monday = _isoweek1monday(year) + today = _ymd2ord(self._year, self._month, self._day) + # Internally, week and day have origin 0 + week, day = divmod(today - week1monday, 7) + if week < 0: + year -= 1 + week1monday = _isoweek1monday(year) + week, day = divmod(today - week1monday, 7) + elif week >= 52: + if today >= _isoweek1monday(year+1): + year += 1 + week = 0 + return year, week+1, day+1 + + # Pickle support. 
+ + def _getstate(self): + yhi, ylo = divmod(self._year, 256) + return (_struct.pack('4B', yhi, ylo, self._month, self._day),) + + def __setstate(self, string): + yhi, ylo, self._month, self._day = (ord(string[0]), ord(string[1]), + ord(string[2]), ord(string[3])) + self._year = yhi * 256 + ylo + + def __reduce__(self): + return (self.__class__, self._getstate()) + +_date_class = date # so functions w/ args named "date" can get at the class + +date.min = date(1, 1, 1) +date.max = date(9999, 12, 31) +date.resolution = timedelta(days=1) + +class tzinfo(object): + """Abstract base class for time zone info classes. + + Subclasses must override the name(), utcoffset() and dst() methods. + """ + __slots__ = () + + def tzname(self, dt): + "datetime -> string name of time zone." + raise NotImplementedError("tzinfo subclass must override tzname()") + + def utcoffset(self, dt): + "datetime -> minutes east of UTC (negative for west of UTC)" + raise NotImplementedError("tzinfo subclass must override utcoffset()") + + def dst(self, dt): + """datetime -> DST offset in minutes east of UTC. + + Return 0 if DST not in effect. utcoffset() must include the DST + offset. + """ + raise NotImplementedError("tzinfo subclass must override dst()") + + def fromutc(self, dt): + "datetime in UTC -> datetime in local time." + + if not isinstance(dt, datetime): + raise TypeError("fromutc() requires a datetime argument") + if dt.tzinfo is not self: + raise ValueError("dt.tzinfo is not self") + + dtoff = dt.utcoffset() + if dtoff is None: + raise ValueError("fromutc() requires a non-None utcoffset() " + "result") + + # See the long comment block at the end of this file for an + # explanation of this algorithm. + dtdst = dt.dst() + if dtdst is None: + raise ValueError("fromutc() requires a non-None dst() result") + delta = dtoff - dtdst + if delta: + dt += delta + dtdst = dt.dst() + if dtdst is None: + raise ValueError("fromutc(): dt.dst gave inconsistent " + "results; cannot convert") + if dtdst: + return dt + dtdst + else: + return dt + + # Pickle support. + + def __reduce__(self): + getinitargs = getattr(self, "__getinitargs__", None) + if getinitargs: + args = getinitargs() + else: + args = () + getstate = getattr(self, "__getstate__", None) + if getstate: + state = getstate() + else: + state = getattr(self, "__dict__", None) or None + if state is None: + return (self.__class__, args) + else: + return (self.__class__, args, state) + +_tzinfo_class = tzinfo + +class time(object): + """Time with time zone. + + Constructors: + + __new__() + + Operators: + + __repr__, __str__ + __cmp__, __hash__ + + Methods: + + strftime() + isoformat() + utcoffset() + tzname() + dst() + + Properties (readonly): + hour, minute, second, microsecond, tzinfo + """ + __slots__ = '_hour', '_minute', '_second', '_microsecond', '_tzinfo', '_hashcode' + + def __new__(cls, hour=0, minute=0, second=0, microsecond=0, tzinfo=None): + """Constructor. 
+ + Arguments: + + hour, minute (required) + second, microsecond (default to zero) + tzinfo (default to None) + """ + # if isinstance(hour, bytes) and len(hour) == 6 and ord(hour[0]) < 24: + # # Pickle support + # self = object.__new__(cls) + # self.__setstate(hour, minute or None) + # self._hashcode = -1 + # return self + hour, minute, second, microsecond = _check_time_fields( + hour, minute, second, microsecond) + _check_tzinfo_arg(tzinfo) + self = object.__new__(cls) + self._hour = hour + self._minute = minute + self._second = second + self._microsecond = microsecond + self._tzinfo = tzinfo + self._hashcode = -1 + return self + + # Read-only field accessors + @property + def hour(self): + """hour (0-23)""" + return self._hour + + @property + def minute(self): + """minute (0-59)""" + return self._minute + + @property + def second(self): + """second (0-59)""" + return self._second + + @property + def microsecond(self): + """microsecond (0-999999)""" + return self._microsecond + + @property + def tzinfo(self): + """timezone info object""" + return self._tzinfo + + # Standard conversions, __hash__ (and helpers) + + # Comparisons of time objects with other. + + def __eq__(self, other): + if isinstance(other, time): + return self._cmp(other) == 0 + else: + return False + + def __ne__(self, other): + if isinstance(other, time): + return self._cmp(other) != 0 + else: + return True + + def __le__(self, other): + if isinstance(other, time): + return self._cmp(other) <= 0 + else: + _cmperror(self, other) + + def __lt__(self, other): + if isinstance(other, time): + return self._cmp(other) < 0 + else: + _cmperror(self, other) + + def __ge__(self, other): + if isinstance(other, time): + return self._cmp(other) >= 0 + else: + _cmperror(self, other) + + def __gt__(self, other): + if isinstance(other, time): + return self._cmp(other) > 0 + else: + _cmperror(self, other) + + def _cmp(self, other): + assert isinstance(other, time) + mytz = self._tzinfo + ottz = other._tzinfo + myoff = otoff = None + + if mytz is ottz: + base_compare = True + else: + myoff = self._utcoffset() + otoff = other._utcoffset() + base_compare = myoff == otoff + + if base_compare: + return _cmp((self._hour, self._minute, self._second, + self._microsecond), + (other._hour, other._minute, other._second, + other._microsecond)) + if myoff is None or otoff is None: + raise TypeError("can't compare offset-naive and offset-aware times") + myhhmm = self._hour * 60 + self._minute - myoff + othhmm = other._hour * 60 + other._minute - otoff + return _cmp((myhhmm, self._second, self._microsecond), + (othhmm, other._second, other._microsecond)) + + def __hash__(self): + """Hash.""" + if self._hashcode == -1: + tzoff = self._utcoffset() + if not tzoff: # zero or None + self._hashcode = hash(self._getstate()[0]) + else: + h, m = divmod(self.hour * 60 + self.minute - tzoff, 60) + if 0 <= h < 24: + self._hashcode = hash(time(h, m, self.second, self.microsecond)) + else: + self._hashcode = hash((h, m, self.second, self.microsecond)) + return self._hashcode + + # Conversion to string + + def _tzstr(self, sep=":"): + """Return formatted timezone offset (+xx:xx) or None.""" + off = self._utcoffset() + if off is not None: + if off < 0: + sign = "-" + off = -off + else: + sign = "+" + hh, mm = divmod(off, 60) + assert 0 <= hh < 24 + off = "%s%02d%s%02d" % (sign, hh, sep, mm) + return off + + def __repr__(self): + """Convert to formal string, for repr().""" + if self._microsecond != 0: + s = ", %d, %d" % (self._second, self._microsecond) + elif 
self._second != 0: + s = ", %d" % self._second + else: + s = "" + module = "datetime." if self.__class__ is time else "" + s= "%s(%d, %d%s)" % (module + self.__class__.__name__, + self._hour, self._minute, s) + if self._tzinfo is not None: + assert s[-1:] == ")" + s = s[:-1] + ", tzinfo=%r" % self._tzinfo + ")" + return s + + def isoformat(self): + """Return the time formatted according to ISO. + + This is 'HH:MM:SS.mmmmmm+zz:zz', or 'HH:MM:SS+zz:zz' if + self.microsecond == 0. + """ + s = _format_time(self._hour, self._minute, self._second, + self._microsecond) + tz = self._tzstr() + if tz: + s += tz + return s + + __str__ = isoformat + + # def strftime(self, format): + # """Format using strftime(). The date part of the timestamp passed + # to underlying strftime should not be used. + # """ + # # The year must be >= _MINYEARFMT else Python's strftime implementation + # # can raise a bogus exception. + # timetuple = (1900, 1, 1, + # self._hour, self._minute, self._second, + # 0, 1, -1) + # return _wrap_strftime(self, format, timetuple) + + def __format__(self, fmt): + if not isinstance(fmt, (str, unicode)): + raise ValueError("__format__ expects str or unicode, not %s" % + fmt.__class__.__name__) + if len(fmt) != 0: + return self.strftime(fmt) + return str(self) + + # Timezone functions + + def utcoffset(self): + """Return the timezone offset in minutes east of UTC (negative west of + UTC).""" + if self._tzinfo is None: + return None + offset = self._tzinfo.utcoffset(None) + offset = _check_utc_offset("utcoffset", offset) + if offset is not None: + offset = timedelta._create(0, offset * 60, 0, True) + return offset + + # Return an integer (or None) instead of a timedelta (or None). + def _utcoffset(self): + if self._tzinfo is None: + return None + offset = self._tzinfo.utcoffset(None) + offset = _check_utc_offset("utcoffset", offset) + return offset + + def tzname(self): + """Return the timezone name. + + Note that the name is 100% informational -- there's no requirement that + it mean anything in particular. For example, "GMT", "UTC", "-500", + "-5:00", "EDT", "US/Eastern", "America/New York" are all valid replies. + """ + if self._tzinfo is None: + return None + name = self._tzinfo.tzname(None) + _check_tzname(name) + return name + + def dst(self): + """Return 0 if DST is not in effect, or the DST offset (in minutes + eastward) if DST is in effect. + + This is purely informational; the DST offset has already been added to + the UTC offset returned by utcoffset() if applicable, so there's no + need to consult dst() unless you're interested in displaying the DST + info. + """ + if self._tzinfo is None: + return None + offset = self._tzinfo.dst(None) + offset = _check_utc_offset("dst", offset) + if offset is not None: + offset = timedelta._create(0, offset * 60, 0, True) + return offset + + # Return an integer (or None) instead of a timedelta (or None). 
+ def _dst(self): + if self._tzinfo is None: + return None + offset = self._tzinfo.dst(None) + offset = _check_utc_offset("dst", offset) + return offset + + def replace(self, hour=None, minute=None, second=None, microsecond=None, + tzinfo=True): + """Return a new time with new values for the specified fields.""" + if hour is None: + hour = self.hour + if minute is None: + minute = self.minute + if second is None: + second = self.second + if microsecond is None: + microsecond = self.microsecond + if tzinfo is True: + tzinfo = self.tzinfo + return time.__new__(type(self), + hour, minute, second, microsecond, tzinfo) + + def __nonzero__(self): + if self.second or self.microsecond: + return True + offset = self._utcoffset() or 0 + return self.hour * 60 + self.minute != offset + + # Pickle support. + + def _getstate(self): + us2, us3 = divmod(self._microsecond, 256) + us1, us2 = divmod(us2, 256) + basestate = _struct.pack('6B', self._hour, self._minute, self._second, + us1, us2, us3) + if self._tzinfo is None: + return (basestate,) + else: + return (basestate, self._tzinfo) + + def __setstate(self, string, tzinfo): + if tzinfo is not None and not isinstance(tzinfo, _tzinfo_class): + raise TypeError("bad tzinfo state arg") + self._hour, self._minute, self._second, us1, us2, us3 = ( + ord(string[0]), ord(string[1]), ord(string[2]), + ord(string[3]), ord(string[4]), ord(string[5])) + self._microsecond = (((us1 << 8) | us2) << 8) | us3 + self._tzinfo = tzinfo + + def __reduce__(self): + return (time, self._getstate()) + +_time_class = time # so functions w/ args named "time" can get at the class + +time.min = time(0, 0, 0) +time.max = time(23, 59, 59, 999999) +time.resolution = timedelta(microseconds=1) + +class datetime(date): + """datetime(year, month, day[, hour[, minute[, second[, microsecond[,tzinfo]]]]]) + + The year, month and day arguments are required. tzinfo may be None, or an + instance of a tzinfo subclass. The remaining arguments may be ints or longs. + """ + __slots__ = date.__slots__ + time.__slots__ + + def __new__(cls, year, month=None, day=None, hour=0, minute=0, second=0, + microsecond=0, tzinfo=None): + # if isinstance(year, bytes) and len(year) == 10 and \ + # 1 <= ord(year[2]) <= 12: + # # Pickle support + # self = object.__new__(cls) + # self.__setstate(year, month) + # self._hashcode = -1 + # return self + year, month, day = _check_date_fields(year, month, day) + hour, minute, second, microsecond = _check_time_fields( + hour, minute, second, microsecond) + _check_tzinfo_arg(tzinfo) + self = object.__new__(cls) + self._year = year + self._month = month + self._day = day + self._hour = hour + self._minute = minute + self._second = second + self._microsecond = microsecond + self._tzinfo = tzinfo + self._hashcode = -1 + return self + + # Read-only field accessors + @property + def hour(self): + """hour (0-23)""" + return self._hour + + @property + def minute(self): + """minute (0-59)""" + return self._minute + + @property + def second(self): + """second (0-59)""" + return self._second + + @property + def microsecond(self): + """microsecond (0-999999)""" + return self._microsecond + + @property + def tzinfo(self): + """timezone info object""" + return self._tzinfo + + @classmethod + def fromtimestamp(cls, timestamp, tz=None): + """Construct a datetime from a POSIX timestamp (like time.time()). + + A timezone info object may be passed in as well. 
+ """ + _check_tzinfo_arg(tz) + converter = _time.localtime if tz is None else _time.gmtime + self = cls._from_timestamp(converter, timestamp, tz) + if tz is not None: + self = tz.fromutc(self) + return self + + @classmethod + def utcfromtimestamp(cls, t): + "Construct a UTC datetime from a POSIX timestamp (like time.time())." + return cls._from_timestamp(_time.gmtime, t, None) + + @classmethod + def _from_timestamp(cls, converter, timestamp, tzinfo): + t_full = timestamp + timestamp = int(_math.floor(timestamp)) + frac = t_full - timestamp + us = _round(frac * 1e6) + + # If timestamp is less than one microsecond smaller than a + # full second, us can be rounded up to 1000000. In this case, + # roll over to seconds, otherwise, ValueError is raised + # by the constructor. + if us == 1000000: + timestamp += 1 + us = 0 + y, m, d, hh, mm, ss, weekday, jday, dst = converter(timestamp) + ss = min(ss, 59) # clamp out leap seconds if the platform has them + return cls(y, m, d, hh, mm, ss, us, tzinfo) + + @classmethod + def now(cls, tz=None): + "Construct a datetime from time.time() and optional time zone info." + t = _time.time() + return cls.fromtimestamp(t, tz) + + @classmethod + def utcnow(cls): + "Construct a UTC datetime from time.time()." + t = _time.time() + return cls.utcfromtimestamp(t) + + @classmethod + def combine(cls, date, time): + "Construct a datetime from a given date and a given time." + if not isinstance(date, _date_class): + raise TypeError("date argument must be a date instance") + if not isinstance(time, _time_class): + raise TypeError("time argument must be a time instance") + return cls(date.year, date.month, date.day, + time.hour, time.minute, time.second, time.microsecond, + time.tzinfo) + + def timetuple(self): + "Return local time tuple compatible with time.localtime()." + dst = self._dst() + if dst is None: + dst = -1 + elif dst: + dst = 1 + return _build_struct_time(self.year, self.month, self.day, + self.hour, self.minute, self.second, + dst) + + def utctimetuple(self): + "Return UTC time tuple compatible with time.gmtime()." + y, m, d = self.year, self.month, self.day + hh, mm, ss = self.hour, self.minute, self.second + offset = self._utcoffset() + if offset: # neither None nor 0 + mm -= offset + y, m, d, hh, mm, ss, _ = _normalize_datetime( + y, m, d, hh, mm, ss, 0, ignore_overflow=True) + return _build_struct_time(y, m, d, hh, mm, ss, 0) + + def date(self): + "Return the date part." + return date(self._year, self._month, self._day) + + def time(self): + "Return the time part, with tzinfo None." + return time(self.hour, self.minute, self.second, self.microsecond) + + def timetz(self): + "Return the time part, with same tzinfo." 
+ return time(self.hour, self.minute, self.second, self.microsecond, + self._tzinfo) + + def replace(self, year=None, month=None, day=None, hour=None, + minute=None, second=None, microsecond=None, tzinfo=True): + """Return a new datetime with new values for the specified fields.""" + if year is None: + year = self.year + if month is None: + month = self.month + if day is None: + day = self.day + if hour is None: + hour = self.hour + if minute is None: + minute = self.minute + if second is None: + second = self.second + if microsecond is None: + microsecond = self.microsecond + if tzinfo is True: + tzinfo = self.tzinfo + return datetime.__new__(type(self), + year, month, day, hour, minute, second, + microsecond, tzinfo) + + def astimezone(self, tz): + if not isinstance(tz, tzinfo): + raise TypeError("tz argument must be an instance of tzinfo") + + mytz = self.tzinfo + if mytz is None: + raise ValueError("astimezone() requires an aware datetime") + + if tz is mytz: + return self + + # Convert self to UTC, and attach the new time zone object. + myoffset = self.utcoffset() + if myoffset is None: + raise ValueError("astimezone() requires an aware datetime") + utc = (self - myoffset).replace(tzinfo=tz) + + # Convert from UTC to tz's local time. + return tz.fromutc(utc) + + # Ways to produce a string. + + def ctime(self): + "Return ctime() style string." + weekday = self.toordinal() % 7 or 7 + return "%s %s %2d %02d:%02d:%02d %04d" % ( + _DAYNAMES[weekday], + _MONTHNAMES[self._month], + self._day, + self._hour, self._minute, self._second, + self._year) + + def isoformat(self, sep='T'): + """Return the time formatted according to ISO. + + This is 'YYYY-MM-DD HH:MM:SS.mmmmmm', or 'YYYY-MM-DD HH:MM:SS' if + self.microsecond == 0. + + If self.tzinfo is not None, the UTC offset is also attached, giving + 'YYYY-MM-DD HH:MM:SS.mmmmmm+HH:MM' or 'YYYY-MM-DD HH:MM:SS+HH:MM'. + + Optional argument sep specifies the separator between date and + time, default 'T'. + """ + s = ("%04d-%02d-%02d%c" % (self._year, self._month, self._day, sep) + + _format_time(self._hour, self._minute, self._second, + self._microsecond)) + off = self._utcoffset() + if off is not None: + if off < 0: + sign = "-" + off = -off + else: + sign = "+" + hh, mm = divmod(off, 60) + s += "%s%02d:%02d" % (sign, hh, mm) + return s + + def __repr__(self): + """Convert to formal string, for repr().""" + L = [self._year, self._month, self._day, # These are never zero + self._hour, self._minute, self._second, self._microsecond] + if L[-1] == 0: + del L[-1] + if L[-1] == 0: + del L[-1] + s = ", ".join(map(str, L)) + module = "datetime." if self.__class__ is datetime else "" + s = "%s(%s)" % (module + self.__class__.__name__, s) + if self._tzinfo is not None: + assert s[-1:] == ")" + s = s[:-1] + ", tzinfo=%r" % self._tzinfo + ")" + return s + + def __str__(self): + "Convert to string, for str()." + return self.isoformat(sep=' ') + + # @classmethod + # def strptime(cls, date_string, format): + # 'string, format -> new datetime parsed from a string (like time.strptime()).' + # from _strptime import _strptime + # # _strptime._strptime returns a two-element tuple. The first + # # element is a time.struct_time object. The second is the + # # microseconds (which are not defined for time.struct_time). 
+ # struct, micros = _strptime(date_string, format) + # return cls(*(struct[0:6] + (micros,))) + + def utcoffset(self): + """Return the timezone offset in minutes east of UTC (negative west of + UTC).""" + if self._tzinfo is None: + return None + offset = self._tzinfo.utcoffset(self) + offset = _check_utc_offset("utcoffset", offset) + if offset is not None: + offset = timedelta._create(0, offset * 60, 0, True) + return offset + + # Return an integer (or None) instead of a timedelta (or None). + def _utcoffset(self): + if self._tzinfo is None: + return None + offset = self._tzinfo.utcoffset(self) + offset = _check_utc_offset("utcoffset", offset) + return offset + + def tzname(self): + """Return the timezone name. + + Note that the name is 100% informational -- there's no requirement that + it mean anything in particular. For example, "GMT", "UTC", "-500", + "-5:00", "EDT", "US/Eastern", "America/New York" are all valid replies. + """ + if self._tzinfo is None: + return None + name = self._tzinfo.tzname(self) + _check_tzname(name) + return name + + def dst(self): + """Return 0 if DST is not in effect, or the DST offset (in minutes + eastward) if DST is in effect. + + This is purely informational; the DST offset has already been added to + the UTC offset returned by utcoffset() if applicable, so there's no + need to consult dst() unless you're interested in displaying the DST + info. + """ + if self._tzinfo is None: + return None + offset = self._tzinfo.dst(self) + offset = _check_utc_offset("dst", offset) + if offset is not None: + offset = timedelta._create(0, offset * 60, 0, True) + return offset + + # Return an integer (or None) instead of a timedelta (or None). + def _dst(self): + if self._tzinfo is None: + return None + offset = self._tzinfo.dst(self) + offset = _check_utc_offset("dst", offset) + return offset + + # Comparisons of datetime objects with other. 
+ + def __eq__(self, other): + if isinstance(other, datetime): + return self._cmp(other) == 0 + elif hasattr(other, "timetuple") and not isinstance(other, date): + return NotImplemented + else: + return False + + def __ne__(self, other): + if isinstance(other, datetime): + return self._cmp(other) != 0 + elif hasattr(other, "timetuple") and not isinstance(other, date): + return NotImplemented + else: + return True + + def __le__(self, other): + if isinstance(other, datetime): + return self._cmp(other) <= 0 + elif hasattr(other, "timetuple") and not isinstance(other, date): + return NotImplemented + else: + _cmperror(self, other) + + def __lt__(self, other): + if isinstance(other, datetime): + return self._cmp(other) < 0 + elif hasattr(other, "timetuple") and not isinstance(other, date): + return NotImplemented + else: + _cmperror(self, other) + + def __ge__(self, other): + if isinstance(other, datetime): + return self._cmp(other) >= 0 + elif hasattr(other, "timetuple") and not isinstance(other, date): + return NotImplemented + else: + _cmperror(self, other) + + def __gt__(self, other): + if isinstance(other, datetime): + return self._cmp(other) > 0 + elif hasattr(other, "timetuple") and not isinstance(other, date): + return NotImplemented + else: + _cmperror(self, other) + + def _cmp(self, other): + assert isinstance(other, datetime) + mytz = self._tzinfo + ottz = other._tzinfo + myoff = otoff = None + + if mytz is ottz: + base_compare = True + else: + if mytz is not None: + myoff = self._utcoffset() + if ottz is not None: + otoff = other._utcoffset() + base_compare = myoff == otoff + + if base_compare: + return _cmp((self._year, self._month, self._day, + self._hour, self._minute, self._second, + self._microsecond), + (other._year, other._month, other._day, + other._hour, other._minute, other._second, + other._microsecond)) + if myoff is None or otoff is None: + raise TypeError("can't compare offset-naive and offset-aware datetimes") + # XXX What follows could be done more efficiently... + diff = self - other # this will take offsets into account + if diff.days < 0: + return -1 + return diff and 1 or 0 + + def _add_timedelta(self, other, factor): + y, m, d, hh, mm, ss, us = _normalize_datetime( + self._year, + self._month, + self._day + other.days * factor, + self._hour, + self._minute, + self._second + other.seconds * factor, + self._microsecond + other.microseconds * factor) + return datetime(y, m, d, hh, mm, ss, us, tzinfo=self._tzinfo) + + def __add__(self, other): + "Add a datetime and a timedelta." + if not isinstance(other, timedelta): + return NotImplemented + return self._add_timedelta(other, 1) + + __radd__ = __add__ + + def __sub__(self, other): + "Subtract two datetimes, or a datetime and a timedelta." 
+ if not isinstance(other, datetime): + if isinstance(other, timedelta): + return self._add_timedelta(other, -1) + return NotImplemented + + delta_d = self.toordinal() - other.toordinal() + delta_s = (self._hour - other._hour) * 3600 + \ + (self._minute - other._minute) * 60 + \ + (self._second - other._second) + delta_us = self._microsecond - other._microsecond + base = timedelta._create(delta_d, delta_s, delta_us, True) + if self._tzinfo is other._tzinfo: + return base + myoff = self._utcoffset() + otoff = other._utcoffset() + if myoff == otoff: + return base + if myoff is None or otoff is None: + raise TypeError("can't subtract offset-naive and offset-aware datetimes") + return base + timedelta(minutes = otoff-myoff) + + def __hash__(self): + if self._hashcode == -1: + tzoff = self._utcoffset() + if tzoff is None: + self._hashcode = hash(self._getstate()[0]) + else: + days = _ymd2ord(self.year, self.month, self.day) + seconds = self.hour * 3600 + (self.minute - tzoff) * 60 + self.second + self._hashcode = hash(timedelta(days, seconds, self.microsecond)) + return self._hashcode + + # Pickle support. + + def _getstate(self): + yhi, ylo = divmod(self._year, 256) + us2, us3 = divmod(self._microsecond, 256) + us1, us2 = divmod(us2, 256) + basestate = _struct.pack('10B', yhi, ylo, self._month, self._day, + self._hour, self._minute, self._second, + us1, us2, us3) + if self._tzinfo is None: + return (basestate,) + else: + return (basestate, self._tzinfo) + + def __setstate(self, string, tzinfo): + if tzinfo is not None and not isinstance(tzinfo, _tzinfo_class): + raise TypeError("bad tzinfo state arg") + (yhi, ylo, self._month, self._day, self._hour, + self._minute, self._second, us1, us2, us3) = (ord(string[0]), + ord(string[1]), ord(string[2]), ord(string[3]), + ord(string[4]), ord(string[5]), ord(string[6]), + ord(string[7]), ord(string[8]), ord(string[9])) + self._year = yhi * 256 + ylo + self._microsecond = (((us1 << 8) | us2) << 8) | us3 + self._tzinfo = tzinfo + + def __reduce__(self): + return (self.__class__, self._getstate()) + + +datetime.min = datetime(1, 1, 1) +datetime.max = datetime(9999, 12, 31, 23, 59, 59, 999999) +datetime.resolution = timedelta(microseconds=1) + + +def _isoweek1monday(year): + # Helper to calculate the day number of the Monday starting week 1 + # XXX This could be done more efficiently + THURSDAY = 3 + firstday = _ymd2ord(year, 1, 1) + firstweekday = (firstday + 6) % 7 # See weekday() above + week1monday = firstday - firstweekday + if firstweekday > THURSDAY: + week1monday += 7 + return week1monday + +""" +Some time zone algebra. For a datetime x, let + x.n = x stripped of its timezone -- its naive time. + x.o = x.utcoffset(), and assuming that doesn't raise an exception or + return None + x.d = x.dst(), and assuming that doesn't raise an exception or + return None + x.s = x's standard offset, x.o - x.d + +Now some derived rules, where k is a duration (timedelta). + +1. x.o = x.s + x.d + This follows from the definition of x.s. + +2. If x and y have the same tzinfo member, x.s = y.s. + This is actually a requirement, an assumption we need to make about + sane tzinfo classes. + +3. The naive UTC time corresponding to x is x.n - x.o. + This is again a requirement for a sane tzinfo class. + +4. (x+k).s = x.s + This follows from #2, and that datimetimetz+timedelta preserves tzinfo. + +5. (x+k).n = x.n + k + Again follows from how arithmetic is defined. + +Now we can explain tz.fromutc(x). 
Let's assume it's an interesting case +(meaning that the various tzinfo methods exist, and don't blow up or return +None when called). + +The function wants to return a datetime y with timezone tz, equivalent to x. +x is already in UTC. + +By #3, we want + + y.n - y.o = x.n [1] + +The algorithm starts by attaching tz to x.n, and calling that y. So +x.n = y.n at the start. Then it wants to add a duration k to y, so that [1] +becomes true; in effect, we want to solve [2] for k: + + (y+k).n - (y+k).o = x.n [2] + +By #1, this is the same as + + (y+k).n - ((y+k).s + (y+k).d) = x.n [3] + +By #5, (y+k).n = y.n + k, which equals x.n + k because x.n=y.n at the start. +Substituting that into [3], + + x.n + k - (y+k).s - (y+k).d = x.n; the x.n terms cancel, leaving + k - (y+k).s - (y+k).d = 0; rearranging, + k = (y+k).s - (y+k).d; by #4, (y+k).s == y.s, so + k = y.s - (y+k).d + +On the RHS, (y+k).d can't be computed directly, but y.s can be, and we +approximate k by ignoring the (y+k).d term at first. Note that k can't be +very large, since all offset-returning methods return a duration of magnitude +less than 24 hours. For that reason, if y is firmly in std time, (y+k).d must +be 0, so ignoring it has no consequence then. + +In any case, the new value is + + z = y + y.s [4] + +It's helpful to step back at look at [4] from a higher level: it's simply +mapping from UTC to tz's standard time. + +At this point, if + + z.n - z.o = x.n [5] + +we have an equivalent time, and are almost done. The insecurity here is +at the start of daylight time. Picture US Eastern for concreteness. The wall +time jumps from 1:59 to 3:00, and wall hours of the form 2:MM don't make good +sense then. The docs ask that an Eastern tzinfo class consider such a time to +be EDT (because it's "after 2"), which is a redundant spelling of 1:MM EST +on the day DST starts. We want to return the 1:MM EST spelling because that's +the only spelling that makes sense on the local wall clock. + +In fact, if [5] holds at this point, we do have the standard-time spelling, +but that takes a bit of proof. We first prove a stronger result. What's the +difference between the LHS and RHS of [5]? Let + + diff = x.n - (z.n - z.o) [6] + +Now + z.n = by [4] + (y + y.s).n = by #5 + y.n + y.s = since y.n = x.n + x.n + y.s = since z and y are have the same tzinfo member, + y.s = z.s by #2 + x.n + z.s + +Plugging that back into [6] gives + + diff = + x.n - ((x.n + z.s) - z.o) = expanding + x.n - x.n - z.s + z.o = cancelling + - z.s + z.o = by #2 + z.d + +So diff = z.d. + +If [5] is true now, diff = 0, so z.d = 0 too, and we have the standard-time +spelling we wanted in the endcase described above. We're done. Contrarily, +if z.d = 0, then we have a UTC equivalent, and are also done. + +If [5] is not true now, diff = z.d != 0, and z.d is the offset we need to +add to z (in effect, z is in tz's standard time, and we need to shift the +local clock into tz's daylight time). + +Let + + z' = z + z.d = z + diff [7] + +and we can again ask whether + + z'.n - z'.o = x.n [8] + +If so, we're done. If not, the tzinfo class is insane, according to the +assumptions we've made. This also requires a bit of proof. 
As before, let's +compute the difference between the LHS and RHS of [8] (and skipping some of +the justifications for the kinds of substitutions we've done several times +already): + + diff' = x.n - (z'.n - z'.o) = replacing z'.n via [7] + x.n - (z.n + diff - z'.o) = replacing diff via [6] + x.n - (z.n + x.n - (z.n - z.o) - z'.o) = + x.n - z.n - x.n + z.n - z.o + z'.o = cancel x.n + - z.n + z.n - z.o + z'.o = cancel z.n + - z.o + z'.o = #1 twice + -z.s - z.d + z'.s + z'.d = z and z' have same tzinfo + z'.d - z.d + +So z' is UTC-equivalent to x iff z'.d = z.d at this point. If they are equal, +we've found the UTC-equivalent so are done. In fact, we stop with [7] and +return z', not bothering to compute z'.d. + +How could z.d and z'd differ? z' = z + z.d [7], so merely moving z' by +a dst() offset, and starting *from* a time already in DST (we know z.d != 0), +would have to change the result dst() returns: we start in DST, and moving +a little further into it takes us out of DST. + +There isn't a sane case where this can happen. The closest it gets is at +the end of DST, where there's an hour in UTC with no spelling in a hybrid +tzinfo class. In US Eastern, that's 5:MM UTC = 0:MM EST = 1:MM EDT. During +that hour, on an Eastern clock 1:MM is taken as being in standard time (6:MM +UTC) because the docs insist on that, but 0:MM is taken as being in daylight +time (4:MM UTC). There is no local time mapping to 5:MM UTC. The local +clock jumps from 1:59 back to 1:00 again, and repeats the 1:MM hour in +standard time. Since that's what the local clock *does*, we want to map both +UTC hours 5:MM and 6:MM to 1:MM Eastern. The result is ambiguous +in local time, but so it goes -- it's the way the local clock works. + +When x = 5:MM UTC is the input to this algorithm, x.o=0, y.o=-5 and y.d=0, +so z=0:MM. z.d=60 (minutes) then, so [5] doesn't hold and we keep going. +z' = z + z.d = 1:MM then, and z'.d=0, and z'.d - z.d = -60 != 0 so [8] +(correctly) concludes that z' is not UTC-equivalent to x. + +Because we know z.d said z was in daylight time (else [5] would have held and +we would have stopped then), and we know z.d != z'.d (else [8] would have held +and we have stopped then), and there are only 2 possible values dst() can +return in Eastern, it follows that z'.d must be 0 (which it is in the example, +but the reasoning doesn't depend on the example -- it depends on there being +two possible dst() outcomes, one zero and the other non-zero). Therefore +z' must be in standard time, and is the spelling we want in this case. + +Note again that z' is not UTC-equivalent as far as the hybrid tzinfo class is +concerned (because it takes z' as being in standard time rather than the +daylight time we intend here), but returning it gives the real-life "local +clock repeats an hour" behavior when mapping the "unspellable" UTC hour into +tz. + +When the input is 6:MM, z=1:MM and z.d=0, and we stop at once, again with +the 1:MM standard time spelling we want. + +So how can this break? One of the assumptions must be violated. Two +possibilities: + +1) [2] effectively says that y.s is invariant across all y belong to a given + time zone. This isn't true if, for political reasons or continental drift, + a region decides to change its base offset from UTC. + +2) There may be versions of "double daylight" time where the tail end of + the analysis gives up a step too early. I haven't thought about that + enough to say. 
+ +In any case, it's clear that the default fromutc() is strong enough to handle +"almost all" time zones: so long as the standard offset is invariant, it +doesn't matter if daylight time transition points change from year to year, or +if daylight time is skipped in some years; it doesn't matter how large or +small dst() may get within its bounds; and it doesn't even matter if some +perverse time zone returns a negative dst()). So a breaking case must be +pretty bizarre, and a tzinfo subclass can override fromutc() if it is. +""" diff --git a/third_party/pythonparser/LICENSE.txt b/third_party/pythonparser/LICENSE.txt new file mode 100644 index 00000000..8da32b6e --- /dev/null +++ b/third_party/pythonparser/LICENSE.txt @@ -0,0 +1,22 @@ +Copyright (c) 2015 whitequark + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/third_party/pythonparser/README.md b/third_party/pythonparser/README.md new file mode 100644 index 00000000..72c62e85 --- /dev/null +++ b/third_party/pythonparser/README.md @@ -0,0 +1,4 @@ +The source code in this directory is forked from +[github.com/m-labs/pythonparser](https://github.com/m-labs/pythonparser). +There are very light modifications to the source code so that it will work with +Grumpy. diff --git a/third_party/pythonparser/__init__.py b/third_party/pythonparser/__init__.py new file mode 100644 index 00000000..a501b376 --- /dev/null +++ b/third_party/pythonparser/__init__.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import, division, print_function, unicode_literals +import sys +from . import source as pythonparser_source, lexer as pythonparser_lexer, parser as pythonparser_parser, diagnostic as pythonparser_diagnostic + +def parse_buffer(buffer, mode="exec", flags=[], version=None, engine=None): + """ + Like :meth:`parse`, but accepts a :class:`source.Buffer` instead of + source and filename, and returns comments as well. 
+ + :see: :meth:`parse` + :return: (:class:`ast.AST`, list of :class:`source.Comment`) + Abstract syntax tree and comments + """ + + if version is None: + version = sys.version_info[0:2] + + if engine is None: + engine = pythonparser_diagnostic.Engine() + + lexer = pythonparser_lexer.Lexer(buffer, version, engine) + if mode in ("single", "eval"): + lexer.interactive = True + + parser = pythonparser_parser.Parser(lexer, version, engine) + parser.add_flags(flags) + + if mode == "exec": + return parser.file_input(), lexer.comments + elif mode == "single": + return parser.single_input(), lexer.comments + elif mode == "eval": + return parser.eval_input(), lexer.comments + +def parse(source, filename="", mode="exec", + flags=[], version=None, engine=None): + """ + Parse a string into an abstract syntax tree. + This is the replacement for the built-in :meth:`..ast.parse`. + + :param source: (string) Source code in the correct encoding + :param filename: (string) Filename of the source (used in diagnostics) + :param mode: (string) Execution mode. Pass ``"exec"`` to parse a module, + ``"single"`` to parse a single (interactive) statement, + and ``"eval"`` to parse an expression. In the last two cases, + ``source`` must be terminated with an empty line + (i.e. end with ``"\\n\\n"``). + :param flags: (list of string) Future flags. + Equivalent to ``from __future__ import ``. + :param version: (2-tuple of int) Major and minor version of Python + syntax to recognize, ``sys.version_info[0:2]`` by default. + :param engine: (:class:`diagnostic.Engine`) Diagnostic engine, + a fresh one is created by default + :return: (:class:`ast.AST`) Abstract syntax tree + :raise: :class:`diagnostic.Error` + if the source code is not well-formed + """ + ast, comments = parse_buffer(pythonparser_source.Buffer(source, filename), + mode, flags, version, engine) + return ast + diff --git a/third_party/pythonparser/algorithm.py b/third_party/pythonparser/algorithm.py new file mode 100644 index 00000000..d9bed74a --- /dev/null +++ b/third_party/pythonparser/algorithm.py @@ -0,0 +1,117 @@ +""" +The :mod:`Diagnostic` module provides several commonly useful +algorithms that operate on abstract syntax trees. +""" + +from __future__ import absolute_import, division, print_function, unicode_literals +from . import ast + +class Visitor: + """ + A node visitor base class that does a traversal + of the abstract syntax tree. + + This class is meant to be subclassed, with the subclass adding + visitor methods. The visitor method should call ``self.generic_visit(node)`` + to continue the traversal; this allows to perform arbitrary + actions both before and after traversing the children of a node. + + The visitor methods for the nodes are ``'visit_'`` + + class name of the node. So a `Try` node visit function would + be `visit_Try`. + """ + + def generic_visit(self, node): + """Called if no explicit visitor function exists for a node.""" + for field_name in node._fields: + self.visit(getattr(node, field_name)) + + def _visit_one(self, node): + visit_attr = "visit_" + type(node).__name__ + if hasattr(self, visit_attr): + return getattr(self, visit_attr)(node) + else: + return self.generic_visit(node) + + def visit(self, obj): + """Visit a node or a list of nodes. 
Other values are ignored""" + if isinstance(obj, list): + return [self.visit(elt) for elt in obj] + elif isinstance(obj, ast.AST): + return self._visit_one(obj) + +class Transformer: + """ + A node transformer base class that does a post-order traversal + of the abstract syntax tree while allowing to replace or remove + the nodes being traversed. + + The return value of the visitor methods is used to replace or remove + the old node. If the return value of the visitor method is ``None``, + the node will be removed from its location, otherwise it is replaced + with the return value. The return value may be the original node + in which case no replacement takes place. + + This class is meant to be subclassed, with the subclass adding + visitor methods. The visitor method should call ``self.generic_visit(node)`` + to continue the traversal; this allows to perform arbitrary + actions both before and after traversing the children of a node. + + The visitor methods for the nodes are ``'visit_'`` + + class name of the node. So a `Try` node visit function would + be `visit_Try`. + """ + + def generic_visit(self, node): + """Called if no explicit visitor function exists for a node.""" + for field_name in node._fields: + setattr(node, field_name, self.visit(getattr(node, field_name))) + return node + + def _visit_one(self, node): + visit_attr = "visit_" + type(node).__name__ + if hasattr(self, visit_attr): + return getattr(self, visit_attr)(node) + else: + return self.generic_visit(node) + + def visit(self, obj): + """Visit a node or a list of nodes. Other values are ignored""" + if isinstance(obj, list): + return list(filter(lambda x: x is not None, map(self.visit, obj))) + elif isinstance(obj, ast.AST): + return self._visit_one(obj) + else: + return obj + +def compare(left, right, compare_locs=False): + """ + An AST comparison function. Returns ``True`` if all fields in + ``left`` are equal to fields in ``right``; if ``compare_locs`` is + true, all locations should match as well. + """ + if type(left) != type(right): + return False + + if isinstance(left, ast.AST): + for field in left._fields: + if not compare(getattr(left, field), getattr(right, field)): + return False + + if compare_locs: + for loc in left._locs: + if getattr(left, loc) != getattr(right, loc): + return False + + return True + elif isinstance(left, list): + if len(left) != len(right): + return False + + for left_elt, right_elt in zip(left, right): + if not compare(left_elt, right_elt): + return False + + return True + else: + return left == right diff --git a/third_party/pythonparser/ast.py b/third_party/pythonparser/ast.py new file mode 100644 index 00000000..8b93ab0e --- /dev/null +++ b/third_party/pythonparser/ast.py @@ -0,0 +1,807 @@ +# encoding: utf-8 + +""" +The :mod:`ast` module contains the classes comprising the Python abstract syntax tree. + +All attributes ending with ``loc`` contain instances of :class:`.source.Range` +or None. All attributes ending with ``_locs`` contain lists of instances of +:class:`.source.Range` or []. + +The attribute ``loc``, present in every class except those inheriting :class:`boolop`, +has a special meaning: it encompasses the entire AST node, so that it is possible +to cut the range contained inside ``loc`` of a parsetree fragment and paste it +somewhere else without altering said parsetree fragment that. + +The AST format for all supported versions is generally normalized to be a superset +of the native :mod:`..ast` module of the latest supported Python version. 
+In particular this affects: + + * :class:`With`: on 2.6-2.7 it uses the 3.0 format. + * :class:`TryExcept` and :class:`TryFinally`: on 2.6-2.7 they're replaced with + :class:`Try` from 3.0. + * :class:`arguments`: on 2.6-3.1 it uses the 3.2 format, with dedicated + :class:`arg` in ``vararg`` and ``kwarg`` slots. +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +# Location mixins + +class commonloc(object): + """ + A mixin common for all nodes. + + :cvar _locs: (tuple of strings) + names of all attributes with location values + + :ivar loc: range encompassing all locations defined for this node + or its children + """ + + _locs = ("loc",) + + def _reprfields(self): + return self._fields + self._locs + + def __repr__(self): + def value(name): + try: + loc = self.__dict__[name] + if isinstance(loc, list): + return "[%s]" % (", ".join(map(repr, loc))) + else: + return repr(loc) + except: + return "(!!!MISSING!!!)" + fields = ", ".join(map(lambda name: "%s=%s" % (name, value(name)), + self._reprfields())) + return "%s(%s)" % (self.__class__.__name__, fields) + + @property + def lineno(self): + return self.loc.line() + +class keywordloc(commonloc): + """ + A mixin common for all keyword statements, e.g. ``pass`` and ``yield expr``. + + :ivar keyword_loc: location of the keyword, e.g. ``yield``. + """ + _locs = commonloc._locs + ("keyword_loc",) + +class beginendloc(commonloc): + """ + A mixin common for nodes with a opening and closing delimiters, e.g. tuples and lists. + + :ivar begin_loc: location of the opening delimiter, e.g. ``(``. + :ivar end_loc: location of the closing delimiter, e.g. ``)``. + """ + _locs = commonloc._locs + ("begin_loc", "end_loc") + +# AST nodes + +class AST(object): + """ + An ancestor of all nodes. + + :cvar _fields: (tuple of strings) + names of all attributes with semantic values + """ + _fields = () + + def __init__(self, **fields): + for field in fields: + setattr(self, field, fields[field]) + +class alias(AST, commonloc): + """ + An import alias, e.g. ``x as y``. + + :ivar name: (string) value to import + :ivar asname: (string) name to add to the environment + :ivar name_loc: location of name + :ivar as_loc: location of ``as`` + :ivar asname_loc: location of asname + """ + _fields = ("name", "asname") + _locs = commonloc._locs + ("name_loc", "as_loc", "asname_loc") + +class arg(AST, commonloc): + """ + A formal argument, e.g. in ``def f(x)`` or ``def f(x: T)``. + + :ivar arg: (string) argument name + :ivar annotation: (:class:`AST`) type annotation, if any; **emitted since 3.0** + :ivar arg_loc: location of argument name + :ivar colon_loc: location of ``:``, if any; **emitted since 3.0** + """ + _fields = ("arg", "annotation") + _locs = commonloc._locs + ("arg_loc", "colon_loc") + +class arguments(AST, beginendloc): + """ + Function definition arguments, e.g. in ``def f(x, y=1, *z, **t)``. + + :ivar args: (list of :class:`arg`) regular formal arguments + :ivar defaults: (list of :class:`AST`) values of default arguments + :ivar vararg: (:class:`arg`) splat formal argument (if any), e.g. in ``*x`` + :ivar kwonlyargs: (list of :class:`arg`) keyword-only (post-\*) formal arguments; + **emitted since 3.0** + :ivar kw_defaults: (list of :class:`AST`) values of default keyword-only arguments; + **emitted since 3.0** + :ivar kwarg: (:class:`arg`) keyword splat formal argument (if any), e.g. 
in ``**x`` + :ivar star_loc: location of ``*``, if any + :ivar dstar_loc: location of ``**``, if any + :ivar equals_locs: locations of ``=`` + :ivar kw_equals_locs: locations of ``=`` of default keyword-only arguments; + **emitted since 3.0** + """ + _fields = ("args", "vararg", "kwonlyargs", "kwarg", "defaults", "kw_defaults") + _locs = beginendloc._locs + ("star_loc", "dstar_loc", "equals_locs", "kw_equals_locs") + +class boolop(AST, commonloc): + """ + Base class for binary boolean operators. + + This class is unlike others in that it does not have the ``loc`` field. + It serves only as an indicator of operation and corresponds to no source + itself; locations are recorded in :class:`BoolOp`. + """ + _locs = () +class And(boolop): + """The ``and`` operator.""" +class Or(boolop): + """The ``or`` operator.""" + +class cmpop(AST, commonloc): + """Base class for comparison operators.""" +class Eq(cmpop): + """The ``==`` operator.""" +class Gt(cmpop): + """The ``>`` operator.""" +class GtE(cmpop): + """The ``>=`` operator.""" +class In(cmpop): + """The ``in`` operator.""" +class Is(cmpop): + """The ``is`` operator.""" +class IsNot(cmpop): + """The ``is not`` operator.""" +class Lt(cmpop): + """The ``<`` operator.""" +class LtE(cmpop): + """The ``<=`` operator.""" +class NotEq(cmpop): + """The ``!=`` (or deprecated ``<>``) operator.""" +class NotIn(cmpop): + """The ``not in`` operator.""" + +class comprehension(AST, commonloc): + """ + A single ``for`` list comprehension clause. + + :ivar target: (assignable :class:`AST`) the variable(s) bound in comprehension body + :ivar iter: (:class:`AST`) the expression being iterated + :ivar ifs: (list of :class:`AST`) the ``if`` clauses + :ivar for_loc: location of the ``for`` keyword + :ivar in_loc: location of the ``in`` keyword + :ivar if_locs: locations of ``if`` keywords + """ + _fields = ("target", "iter", "ifs") + _locs = commonloc._locs + ("for_loc", "in_loc", "if_locs") + +class excepthandler(AST, commonloc): + """Base class for the exception handler.""" +class ExceptHandler(excepthandler): + """ + An exception handler, e.g. ``except x as y:· z``. + + :ivar type: (:class:`AST`) type of handled exception, if any + :ivar name: (assignable :class:`AST` **until 3.0**, string **since 3.0**) + variable bound to exception, if any + :ivar body: (list of :class:`AST`) code to execute when exception is caught + :ivar except_loc: location of ``except`` + :ivar as_loc: location of ``as``, if any + :ivar name_loc: location of variable name + :ivar colon_loc: location of ``:`` + """ + _fields = ("type", "name", "body") + _locs = excepthandler._locs + ("except_loc", "as_loc", "name_loc", "colon_loc") + +class expr(AST, commonloc): + """Base class for expression nodes.""" +class Attribute(expr): + """ + An attribute access, e.g. ``x.y``. + + :ivar value: (:class:`AST`) left-hand side + :ivar attr: (string) attribute name + """ + _fields = ("value", "attr", "ctx") + _locs = expr._locs + ("dot_loc", "attr_loc") +class BinOp(expr): + """ + A binary operation, e.g. ``x + y``. + + :ivar left: (:class:`AST`) left-hand side + :ivar op: (:class:`operator`) operator + :ivar right: (:class:`AST`) right-hand side + """ + _fields = ("left", "op", "right") +class BoolOp(expr): + """ + A boolean operation, e.g. ``x and y``. 
+ + :ivar op: (:class:`boolop`) operator + :ivar values: (list of :class:`AST`) operands + :ivar op_locs: locations of operators + """ + _fields = ("op", "values") + _locs = expr._locs + ("op_locs",) +class Call(expr, beginendloc): + """ + A function call, e.g. ``f(x, y=1, *z, **t)``. + + :ivar func: (:class:`AST`) function to call + :ivar args: (list of :class:`AST`) regular arguments + :ivar keywords: (list of :class:`keyword`) keyword arguments + :ivar starargs: (:class:`AST`) splat argument (if any), e.g. in ``*x`` + :ivar kwargs: (:class:`AST`) keyword splat argument (if any), e.g. in ``**x`` + :ivar star_loc: location of ``*``, if any + :ivar dstar_loc: location of ``**``, if any + """ + _fields = ("func", "args", "keywords", "starargs", "kwargs") + _locs = beginendloc._locs + ("star_loc", "dstar_loc") +class Compare(expr): + """ + A comparison operation, e.g. ``x < y`` or ``x < y > z``. + + :ivar left: (:class:`AST`) left-hand + :ivar ops: (list of :class:`cmpop`) compare operators + :ivar comparators: (list of :class:`AST`) compare values + """ + _fields = ("left", "ops", "comparators") +class Dict(expr, beginendloc): + """ + A dictionary, e.g. ``{x: y}``. + + :ivar keys: (list of :class:`AST`) keys + :ivar values: (list of :class:`AST`) values + :ivar colon_locs: locations of ``:`` + """ + _fields = ("keys", "values") + _locs = beginendloc._locs + ("colon_locs",) +class DictComp(expr, beginendloc): + """ + A list comprehension, e.g. ``{x: y for x,y in z}``. + + **Emitted since 2.7.** + + :ivar key: (:class:`AST`) key part of comprehension body + :ivar value: (:class:`AST`) value part of comprehension body + :ivar generators: (list of :class:`comprehension`) ``for`` clauses + :ivar colon_loc: location of ``:`` + """ + _fields = ("key", "value", "generators") + _locs = beginendloc._locs + ("colon_loc",) +class Ellipsis(expr): + """The ellipsis, e.g. in ``x[...]``.""" +class GeneratorExp(expr, beginendloc): + """ + A generator expression, e.g. ``(x for x in y)``. + + :ivar elt: (:class:`AST`) expression body + :ivar generators: (list of :class:`comprehension`) ``for`` clauses + """ + _fields = ("elt", "generators") +class IfExp(expr): + """ + A conditional expression, e.g. ``x if y else z``. + + :ivar test: (:class:`AST`) condition + :ivar body: (:class:`AST`) value if true + :ivar orelse: (:class:`AST`) value if false + :ivar if_loc: location of ``if`` + :ivar else_loc: location of ``else`` + """ + _fields = ("test", "body", "orelse") + _locs = expr._locs + ("if_loc", "else_loc") +class Lambda(expr): + """ + A lambda expression, e.g. ``lambda x: x*x``. + + :ivar args: (:class:`arguments`) arguments + :ivar body: (:class:`AST`) body + :ivar lambda_loc: location of ``lambda`` + :ivar colon_loc: location of ``:`` + """ + _fields = ("args", "body") + _locs = expr._locs + ("lambda_loc", "colon_loc") +class List(expr, beginendloc): + """ + A list, e.g. ``[x, y]``. + + :ivar elts: (list of :class:`AST`) elements + """ + _fields = ("elts", "ctx") +class ListComp(expr, beginendloc): + """ + A list comprehension, e.g. ``[x for x in y]``. + + :ivar elt: (:class:`AST`) comprehension body + :ivar generators: (list of :class:`comprehension`) ``for`` clauses + """ + _fields = ("elt", "generators") +class Name(expr): + """ + An identifier, e.g. ``x``. + + :ivar id: (string) name + """ + _fields = ("id", "ctx") +class NameConstant(expr): + """ + A named constant, e.g. ``None``. 
+ + :ivar value: Python value, one of ``None``, ``True`` or ``False`` + """ + _fields = ("value",) +class Num(expr): + """ + An integer, floating point or complex number, e.g. ``1``, ``1.0`` or ``1.0j``. + + :ivar n: (int, float or complex) value + """ + _fields = ("n",) +class Repr(expr, beginendloc): + """ + A repr operation, e.g. ``\`x\``` + + **Emitted until 3.0.** + + :ivar value: (:class:`AST`) value + """ + _fields = ("value",) +class Set(expr, beginendloc): + """ + A set, e.g. ``{x, y}``. + + **Emitted since 2.7.** + + :ivar elts: (list of :class:`AST`) elements + """ + _fields = ("elts",) +class SetComp(expr, beginendloc): + """ + A set comprehension, e.g. ``{x for x in y}``. + + **Emitted since 2.7.** + + :ivar elt: (:class:`AST`) comprehension body + :ivar generators: (list of :class:`comprehension`) ``for`` clauses + """ + _fields = ("elt", "generators") +class Str(expr, beginendloc): + """ + A string, e.g. ``"x"``. + + :ivar s: (string) value + """ + _fields = ("s",) +class Starred(expr): + """ + A starred expression, e.g. ``*x`` in ``*x, y = z``. + + :ivar value: (:class:`AST`) expression + :ivar star_loc: location of ``*`` + """ + _fields = ("value", "ctx") + _locs = expr._locs + ("star_loc",) +class Subscript(expr, beginendloc): + """ + A subscript operation, e.g. ``x[1]``. + + :ivar value: (:class:`AST`) object being sliced + :ivar slice: (:class:`slice`) slice + """ + _fields = ("value", "slice", "ctx") +class Tuple(expr, beginendloc): + """ + A tuple, e.g. ``(x,)`` or ``x,y``. + + :ivar elts: (list of nodes) elements + """ + _fields = ("elts", "ctx") +class UnaryOp(expr): + """ + An unary operation, e.g. ``+x``. + + :ivar op: (:class:`unaryop`) operator + :ivar operand: (:class:`AST`) operand + """ + _fields = ("op", "operand") +class Yield(expr): + """ + A yield expression, e.g. ``yield x``. + + :ivar value: (:class:`AST`) yielded value + :ivar yield_loc: location of ``yield`` + """ + _fields = ("value",) + _locs = expr._locs + ("yield_loc",) +class YieldFrom(expr): + """ + A yield from expression, e.g. ``yield from x``. + + :ivar value: (:class:`AST`) yielded value + :ivar yield_loc: location of ``yield`` + :ivar from_loc: location of ``from`` + """ + _fields = ("value",) + _locs = expr._locs + ("yield_loc", "from_loc") + +# expr_context +# AugLoad +# AugStore +# Del +# Load +# Param +# Store + +class keyword(AST, commonloc): + """ + A keyword actual argument, e.g. in ``f(x=1)``. 
+ + :ivar arg: (string) name + :ivar value: (:class:`AST`) value + :ivar equals_loc: location of ``=`` + """ + _fields = ("arg", "value") + _locs = commonloc._locs + ("arg_loc", "equals_loc") + +class mod(AST, commonloc): + """Base class for modules (groups of statements).""" + _fields = ("body",) +class Expression(mod): + """A group of statements parsed as if for :func:`eval`.""" +class Interactive(mod): + """A group of statements parsed as if it was REPL input.""" +class Module(mod): + """A group of statements parsed as if it was a file.""" + +class operator(AST, commonloc): + """Base class for numeric binary operators.""" +class Add(operator): + """The ``+`` operator.""" +class BitAnd(operator): + """The ``&`` operator.""" +class BitOr(operator): + """The ``|`` operator.""" +class BitXor(operator): + """The ``^`` operator.""" +class Div(operator): + """The ``\\`` operator.""" +class FloorDiv(operator): + """The ``\\\\`` operator.""" +class LShift(operator): + """The ``<<`` operator.""" +class MatMult(operator): + """The ``@`` operator.""" +class Mod(operator): + """The ``%`` operator.""" +class Mult(operator): + """The ``*`` operator.""" +class Pow(operator): + """The ``**`` operator.""" +class RShift(operator): + """The ``>>`` operator.""" +class Sub(operator): + """The ``-`` operator.""" + +class slice(AST, commonloc): + """Base class for slice operations.""" +class ExtSlice(slice): + """ + The multiple slice, e.g. in ``x[0:1, 2:3]``. + Note that multiple slices with only integer indexes + will appear as instances of :class:`Index`. + + :ivar dims: (:class:`slice`) sub-slices + """ + _fields = ("dims",) +class Index(slice): + """ + The index, e.g. in ``x[1]`` or ``x[1, 2]``. + + :ivar value: (:class:`AST`) index + """ + _fields = ("value",) +class Slice(slice): + """ + The slice, e.g. in ``x[0:1]`` or ``x[0:1:2]``. + + :ivar lower: (:class:`AST`) lower bound, if any + :ivar upper: (:class:`AST`) upper bound, if any + :ivar step: (:class:`AST`) iteration step, if any + :ivar bound_colon_loc: location of first semicolon + :ivar step_colon_loc: location of second semicolon, if any + """ + _fields = ("lower", "upper", "step") + _locs = slice._locs + ("bound_colon_loc", "step_colon_loc") + +class stmt(AST, commonloc): + """Base class for statement nodes.""" +class Assert(stmt, keywordloc): + """ + The ``assert x, msg`` statement. + + :ivar test: (:class:`AST`) condition + :ivar msg: (:class:`AST`) message, if any + """ + _fields = ("test", "msg") +class Assign(stmt): + """ + The ``=`` statement, e.g. in ``x = 1`` or ``x = y = 1``. + + :ivar targets: (list of assignable :class:`AST`) left-hand sides + :ivar value: (:class:`AST`) right-hand side + :ivar op_locs: location of equality signs corresponding to ``targets`` + """ + _fields = ("targets", "value") + _locs = stmt._locs + ("op_locs",) +class AugAssign(stmt): + """ + The operator-assignment statement, e.g. ``+=``. + + :ivar target: (assignable :class:`AST`) left-hand side + :ivar op: (:class:`operator`) operator + :ivar value: (:class:`AST`) right-hand side + """ + _fields = ("target", "op", "value") +class Break(stmt, keywordloc): + """The ``break`` statement.""" +class ClassDef(stmt, keywordloc): + """ + The ``class x(z, y):· t`` (2.6) or + ``class x(y, z=1, *t, **u):· v`` (3.0) statement. + + :ivar name: (string) name + :ivar bases: (list of :class:`AST`) base classes + :ivar keywords: (list of :class:`keyword`) keyword arguments; **emitted since 3.0** + :ivar starargs: (:class:`AST`) splat argument (if any), e.g. 
in ``*x``; **emitted since 3.0** + :ivar kwargs: (:class:`AST`) keyword splat argument (if any), e.g. in ``**x``; **emitted since 3.0** + :ivar body: (list of :class:`AST`) body + :ivar decorator_list: (list of :class:`AST`) decorators + :ivar keyword_loc: location of ``class`` + :ivar name_loc: location of name + :ivar lparen_loc: location of ``(``, if any + :ivar star_loc: location of ``*``, if any; **emitted since 3.0** + :ivar dstar_loc: location of ``**``, if any; **emitted since 3.0** + :ivar rparen_loc: location of ``)``, if any + :ivar colon_loc: location of ``:`` + :ivar at_locs: locations of decorator ``@`` + """ + _fields = ("name", "bases", "keywords", "starargs", "kwargs", "body", "decorator_list") + _locs = keywordloc._locs + ("name_loc", "lparen_loc", "star_loc", "dstar_loc", "rparen_loc", + "colon_loc", "at_locs") +class Continue(stmt, keywordloc): + """The ``continue`` statement.""" +class Delete(stmt, keywordloc): + """ + The ``del x, y`` statement. + + :ivar targets: (list of :class:`Name`) + """ + _fields = ("targets",) +class Exec(stmt, keywordloc): + """ + The ``exec code in locals, globals`` statement. + + **Emitted until 3.0.** + + :ivar body: (:class:`AST`) code + :ivar locals: (:class:`AST`) locals + :ivar globals: (:class:`AST`) globals + :ivar keyword_loc: location of ``exec`` + :ivar in_loc: location of ``in`` + """ + _fields = ("body", "locals", "globals") + _locs = keywordloc._locs + ("in_loc",) +class Expr(stmt): + """ + An expression in statement context. The value of expression is discarded. + + :ivar value: (:class:`expr`) value + """ + _fields = ("value",) +class For(stmt, keywordloc): + """ + The ``for x in y:· z·else:· t`` statement. + + :ivar target: (assignable :class:`AST`) loop variable + :ivar iter: (:class:`AST`) loop collection + :ivar body: (list of :class:`AST`) code for every iteration + :ivar orelse: (list of :class:`AST`) code if empty + :ivar keyword_loc: location of ``for`` + :ivar in_loc: location of ``in`` + :ivar for_colon_loc: location of colon after ``for`` + :ivar else_loc: location of ``else``, if any + :ivar else_colon_loc: location of colon after ``else``, if any + """ + _fields = ("target", "iter", "body", "orelse") + _locs = keywordloc._locs + ("in_loc", "for_colon_loc", "else_loc", "else_colon_loc") +class FunctionDef(stmt, keywordloc): + """ + The ``def f(x):· y`` (2.6) or ``def f(x) -> t:· y`` (3.0) statement. + + :ivar name: (string) name + :ivar args: (:class:`arguments`) formal arguments + :ivar returns: (:class:`AST`) return type annotation; **emitted since 3.0** + :ivar body: (list of :class:`AST`) body + :ivar decorator_list: (list of :class:`AST`) decorators + :ivar keyword_loc: location of ``def`` + :ivar name_loc: location of name + :ivar arrow_loc: location of ``->``, if any; **emitted since 3.0** + :ivar colon_loc: location of ``:``, if any + :ivar at_locs: locations of decorator ``@`` + """ + _fields = ("name", "args", "returns", "body", "decorator_list") + _locs = keywordloc._locs + ("name_loc", "arrow_loc", "colon_loc", "at_locs") +class Global(stmt, keywordloc): + """ + The ``global x, y`` statement. + + :ivar names: (list of string) names + :ivar name_locs: locations of names + """ + _fields = ("names",) + _locs = keywordloc._locs + ("name_locs",) +class If(stmt, keywordloc): + """ + The ``if x:· y·else:· z`` or ``if x:· y·elif: z· t`` statement. 
+ + :ivar test: (:class:`AST`) condition + :ivar body: (list of :class:`AST`) code if true + :ivar orelse: (list of :class:`AST`) code if false + :ivar if_colon_loc: location of colon after ``if`` or ``elif`` + :ivar else_loc: location of ``else``, if any + :ivar else_colon_loc: location of colon after ``else``, if any + """ + _fields = ("test", "body", "orelse") + _locs = keywordloc._locs + ("if_colon_loc", "else_loc", "else_colon_loc") +class Import(stmt, keywordloc): + """ + The ``import x, y`` statement. + + :ivar names: (list of :class:`alias`) names + """ + _fields = ("names",) +class ImportFrom(stmt, keywordloc): + """ + The ``from ...x import y, z`` or ``from x import (y, z)`` or + ``from x import *`` statement. + + :ivar names: (list of :class:`alias`) names + :ivar module: (string) module name, if any + :ivar level: (integer) amount of dots before module name + :ivar keyword_loc: location of ``from`` + :ivar dots_loc: location of dots, if any + :ivar module_loc: location of module name, if any + :ivar import_loc: location of ``import`` + :ivar lparen_loc: location of ``(``, if any + :ivar rparen_loc: location of ``)``, if any + """ + _fields = ("names", "module", "level") + _locs = keywordloc._locs + ("dots_loc", "module_loc", "import_loc", "lparen_loc", "rparen_loc") +class Nonlocal(stmt, keywordloc): + """ + The ``nonlocal x, y`` statement. + + **Emitted since 3.0.** + + :ivar names: (list of string) names + :ivar name_locs: locations of names + """ + _fields = ("names",) + _locs = keywordloc._locs + ("name_locs",) +class Pass(stmt, keywordloc): + """The ``pass`` statement.""" +class Print(stmt, keywordloc): + """ + The ``print >>x, y, z,`` statement. + + **Emitted until 3.0 or until print_function future flag is activated.** + + :ivar dest: (:class:`AST`) destination stream, if any + :ivar values: (list of :class:`AST`) values to print + :ivar nl: (boolean) whether to print newline after values + :ivar dest_loc: location of ``>>`` + """ + _fields = ("dest", "values", "nl") + _locs = keywordloc._locs + ("dest_loc",) +class Raise(stmt, keywordloc): + """ + The ``raise exc, arg, traceback`` (2.6) or + or ``raise exc from cause`` (3.0) statement. + + :ivar exc: (:class:`AST`) exception type or instance + :ivar cause: (:class:`AST`) cause of exception, if any; **emitted since 3.0** + :ivar inst: (:class:`AST`) exception instance or argument list, if any; **emitted until 3.0** + :ivar tback: (:class:`AST`) traceback, if any; **emitted until 3.0** + :ivar from_loc: location of ``from``, if any; **emitted since 3.0** + """ + _fields = ("exc", "cause", "inst", "tback") + _locs = keywordloc._locs + ("from_loc",) +class Return(stmt, keywordloc): + """ + The ``return x`` statement. + + :ivar value: (:class:`AST`) return value, if any + """ + _fields = ("value",) +class Try(stmt, keywordloc): + """ + The ``try:· x·except y:· z·else:· t`` or + ``try:· x·finally:· y`` statement. 
+ + :ivar body: (list of :class:`AST`) code to try + :ivar handlers: (list of :class:`ExceptHandler`) exception handlers + :ivar orelse: (list of :class:`AST`) code if no exception + :ivar finalbody: (list of :class:`AST`) code to finalize + :ivar keyword_loc: location of ``try`` + :ivar try_colon_loc: location of ``:`` after ``try`` + :ivar else_loc: location of ``else`` + :ivar else_colon_loc: location of ``:`` after ``else`` + :ivar finally_loc: location of ``finally`` + :ivar finally_colon_loc: location of ``:`` after ``finally`` + """ + _fields = ("body", "handlers", "orelse", "finalbody") + _locs = keywordloc._locs + ("try_colon_loc", "else_loc", "else_colon_loc", + "finally_loc", "finally_colon_loc",) +class While(stmt, keywordloc): + """ + The ``while x:· y·else:· z`` statement. + + :ivar test: (:class:`AST`) condition + :ivar body: (list of :class:`AST`) code for every iteration + :ivar orelse: (list of :class:`AST`) code if empty + :ivar keyword_loc: location of ``while`` + :ivar while_colon_loc: location of colon after ``while`` + :ivar else_loc: location of ``else``, if any + :ivar else_colon_loc: location of colon after ``else``, if any + """ + _fields = ("test", "body", "orelse") + _locs = keywordloc._locs + ("while_colon_loc", "else_loc", "else_colon_loc") +class With(stmt, keywordloc): + """ + The ``with x as y:· z`` statement. + + :ivar items: (list of :class:`withitem`) bindings + :ivar body: (:class:`AST`) body + :ivar keyword_loc: location of ``with`` + :ivar colon_loc: location of ``:`` + """ + _fields = ("items", "body") + _locs = keywordloc._locs + ("colon_loc",) + +class unaryop(AST, commonloc): + """Base class for unary numeric and boolean operators.""" +class Invert(unaryop): + """The ``~`` operator.""" +class Not(unaryop): + """The ``not`` operator.""" +class UAdd(unaryop): + """The unary ``+`` operator.""" +class USub(unaryop): + """The unary ``-`` operator.""" + +class withitem(AST, commonloc): + """ + The ``x as y`` clause in ``with x as y:``. + + :ivar context_expr: (:class:`AST`) context + :ivar optional_vars: (assignable :class:`AST`) context binding, if any + :ivar as_loc: location of ``as``, if any + """ + _fields = ("context_expr", "optional_vars") + _locs = commonloc._locs + ("as_loc",) diff --git a/third_party/pythonparser/diagnostic.py b/third_party/pythonparser/diagnostic.py new file mode 100644 index 00000000..45eb3b4e --- /dev/null +++ b/third_party/pythonparser/diagnostic.py @@ -0,0 +1,178 @@ +""" +The :mod:`Diagnostic` module concerns itself with processing +and presentation of diagnostic messages. +""" + +from __future__ import absolute_import, division, print_function, unicode_literals +from functools import reduce +from contextlib import contextmanager +import sys, re + +class Diagnostic: + """ + A diagnostic message highlighting one or more locations + in a single source buffer. 
+ + :ivar level: (one of ``LEVELS``) severity level + :ivar reason: (format string) diagnostic message + :ivar arguments: (dictionary) substitutions for ``reason`` + :ivar location: (:class:`pythonparser.source.Range`) most specific + location of the problem + :ivar highlights: (list of :class:`pythonparser.source.Range`) + secondary locations related to the problem that are + likely to be on the same line + :ivar notes: (list of :class:`Diagnostic`) + secondary diagnostics highlighting relevant source + locations that are unlikely to be on the same line + """ + + LEVELS = ["note", "warning", "error", "fatal"] + """ + Available diagnostic levels: + * ``fatal`` indicates an unrecoverable error. + * ``error`` indicates an error that leaves a possibility of + processing more code, e.g. a recoverable parsing error. + * ``warning`` indicates a potential problem. + * ``note`` level diagnostics do not appear by itself, + but are attached to other diagnostics to refer to + and describe secondary source locations. + """ + + def __init__(self, level, reason, arguments, location, + highlights=None, notes=None): + if level not in self.LEVELS: + raise ValueError("level must be one of Diagnostic.LEVELS") + + if highlights is None: + highlights = [] + if notes is None: + notes = [] + + if len(set(map(lambda x: x.source_buffer, + [location] + highlights))) > 1: + raise ValueError("location and highlights must refer to the same source buffer") + + self.level, self.reason, self.arguments = \ + level, reason, arguments + self.location, self.highlights, self.notes = \ + location, highlights, notes + + def message(self): + """ + Returns the formatted message. + """ + return self.reason.format(**self.arguments) + + def render(self, only_line=False, colored=False): + """ + Returns the human-readable location of the diagnostic in the source, + the formatted message, the source line corresponding + to ``location`` and a line emphasizing the problematic + locations in the source line using ASCII art, as a list of lines. + Appends the result of calling :meth:`render` on ``notes``, if any. 
+ + For example: :: + + :1:8-9: error: cannot add integer and string + x + (1 + "a") + ~ ^ ~~~ + + :param only_line: (bool) If true, only print line number, not line and column range + """ + source_line = self.location.source_line().rstrip("\n") + highlight_line = bytearray(re.sub(r"[^\t]", " ", source_line), "utf-8") + + for hilight in self.highlights: + if hilight.line() == self.location.line(): + lft, rgt = hilight.column_range() + highlight_line[lft:rgt] = bytearray("~", "utf-8") * (rgt - lft) + + lft, rgt = self.location.column_range() + if rgt == lft: # Expand zero-length ranges to one ^ + rgt = lft + 1 + highlight_line[lft:rgt] = bytearray("^", "utf-8") * (rgt - lft) + + if only_line: + location = "%s:%s" % (self.location.source_buffer.name, self.location.line()) + else: + location = str(self.location) + + notes = list(self.notes) + if self.level != "note": + expanded_location = self.location.expanded_from + while expanded_location is not None: + notes.insert(0, Diagnostic("note", + "expanded from here", {}, + self.location.expanded_from)) + expanded_location = expanded_location.expanded_from + + rendered_notes = reduce(list.__add__, [note.render(only_line, colored) + for note in notes], []) + if colored: + if self.level in ("error", "fatal"): + level_color = 31 # red + elif self.level == "warning": + level_color = 35 # magenta + else: # level == "note" + level_color = 30 # gray + return [ + "\x1b[1;37m{}: \x1b[{}m{}:\x1b[37m {}\x1b[0m". + format(location, level_color, self.level, self.message()), + source_line, + "\x1b[1;32m{}\x1b[0m".format(highlight_line.decode("utf-8")) + ] + rendered_notes + else: + return [ + "{}: {}: {}".format(location, self.level, self.message()), + source_line, + highlight_line.decode("utf-8") + ] + rendered_notes + + +class Error(Exception): + """ + :class:`Error` is an exception which carries a :class:`Diagnostic`. + + :ivar diagnostic: (:class:`Diagnostic`) the diagnostic + """ + def __init__(self, diagnostic): + self.diagnostic = diagnostic + + def __str__(self): + return "\n".join(self.diagnostic.render()) + +class Engine: + """ + :class:`Engine` is a single point through which diagnostics from + lexer, parser and any AST consumer are dispatched. + + :ivar all_errors_are_fatal: if true, an exception is raised not only + for ``fatal`` diagnostic level, but also ``error`` + """ + def __init__(self, all_errors_are_fatal=False): + self.all_errors_are_fatal = all_errors_are_fatal + self._appended_notes = [] + + def process(self, diagnostic): + """ + The default implementation of :meth:`process` renders non-fatal + diagnostics to ``sys.stderr``, and raises fatal ones as a :class:`Error`. + """ + diagnostic.notes += self._appended_notes + self.render_diagnostic(diagnostic) + if diagnostic.level == "fatal" or \ + (self.all_errors_are_fatal and diagnostic.level == "error"): + raise Error(diagnostic) + + @contextmanager + def context(self, *notes): + """ + A context manager that appends ``note`` to every diagnostic processed by + this engine. + """ + self._appended_notes += notes + yield + del self._appended_notes[-len(notes):] + + def render_diagnostic(self, diagnostic): + sys.stderr.write("\n".join(diagnostic.render()) + "\n") diff --git a/third_party/pythonparser/lexer.py b/third_party/pythonparser/lexer.py new file mode 100644 index 00000000..13c28155 --- /dev/null +++ b/third_party/pythonparser/lexer.py @@ -0,0 +1,611 @@ +""" +The :mod:`lexer` module concerns itself with tokenizing Python source. 
+""" + +from __future__ import absolute_import, division, print_function, unicode_literals +from . import source, diagnostic +import re +import unicodedata +import sys + +if sys.version_info[0] == 3: + unichr = chr + byte = lambda x: bytes([x]) +else: + byte = chr + +class Token: + """ + The :class:`Token` encapsulates a single lexer token and its location + in the source code. + + :ivar loc: (:class:`pythonparser.source.Range`) token location + :ivar kind: (string) token kind + :ivar value: token value; None or a kind-specific class + """ + def __init__(self, loc, kind, value=None): + self.loc, self.kind, self.value = loc, kind, value + + def __repr__(self): + return "Token(%s, \"%s\", %s)" % (repr(self.loc), self.kind, repr(self.value)) + +class Lexer: + """ + The :class:`Lexer` class extracts tokens and comments from + a :class:`pythonparser.source.Buffer`. + + :class:`Lexer` is an iterable. + + :ivar version: (tuple of (*major*, *minor*)) + the version of Python, determining the grammar used + :ivar source_buffer: (:class:`pythonparser.source.Buffer`) + the source buffer + :ivar diagnostic_engine: (:class:`pythonparser.diagnostic.Engine`) + the diagnostic engine + :ivar offset: (integer) character offset into ``source_buffer`` + indicating where the next token will be recognized + :ivar interactive: (boolean) whether a completely empty line + should generate a NEWLINE token, for use in REPLs + """ + + _reserved_2_6 = frozenset([ + "!=", "%", "%=", "&", "&=", "(", ")", "*", "**", "**=", "*=", "+", "+=", + ",", "-", "-=", ".", "/", "//", "//=", "/=", ":", ";", "<", "<<", "<<=", + "<=", "<>", "=", "==", ">", ">=", ">>", ">>=", "@", "[", "]", "^", "^=", "`", + "and", "as", "assert", "break", "class", "continue", "def", "del", "elif", + "else", "except", "exec", "finally", "for", "from", "global", "if", "import", + "in", "is", "lambda", "not", "or", "pass", "print", "raise", "return", "try", + "while", "with", "yield", "{", "|", "|=", "}", "~" + ]) + + _reserved_3_0 = _reserved_2_6 \ + - set(["<>", "`", "exec", "print"]) \ + | set(["->", "...", "False", "None", "nonlocal", "True"]) + + _reserved_3_1 = _reserved_3_0 \ + | set(["<>"]) + + _reserved_3_5 = _reserved_3_1 \ + | set(["@", "@="]) + + _reserved = { + (2, 6): _reserved_2_6, + (2, 7): _reserved_2_6, + (3, 0): _reserved_3_0, + (3, 1): _reserved_3_1, + (3, 2): _reserved_3_1, + (3, 3): _reserved_3_1, + (3, 4): _reserved_3_1, + (3, 5): _reserved_3_5, + } + """ + A map from a tuple (*major*, *minor*) corresponding to Python version to + :class:`frozenset`\s of keywords. + """ + + _string_prefixes_3_1 = frozenset(["", "r", "b", "br"]) + _string_prefixes_3_3 = frozenset(["", "r", "u", "b", "br", "rb"]) + + # holy mother of god why + _string_prefixes = { + (2, 6): frozenset(["", "r", "u", "ur"]), + (2, 7): frozenset(["", "r", "u", "ur", "b", "br"]), + (3, 0): frozenset(["", "r", "b"]), + (3, 1): _string_prefixes_3_1, + (3, 2): _string_prefixes_3_1, + (3, 3): _string_prefixes_3_3, + (3, 4): _string_prefixes_3_3, + (3, 5): _string_prefixes_3_3, + } + """ + A map from a tuple (*major*, *minor*) corresponding to Python version to + :class:`frozenset`\s of string prefixes. 
+ """ + + def __init__(self, source_buffer, version, diagnostic_engine, interactive=False): + self.source_buffer = source_buffer + self.version = version + self.diagnostic_engine = diagnostic_engine + self.interactive = interactive + self.print_function = False + self.unicode_literals = self.version >= (3, 0) + + self.offset = 0 + self.new_line = True + self.indent = [(0, source.Range(source_buffer, 0, 0), "")] + self.comments = [] + self.queue = [] + self.parentheses = [] + self.curly_braces = [] + self.square_braces = [] + + try: + reserved = self._reserved[version] + except KeyError: + raise NotImplementedError("pythonparser.lexer.Lexer cannot lex Python %s" % str(version)) + + # Sort for the regexp to obey longest-match rule. + re_reserved = sorted(reserved, reverse=True, key=len) + re_keywords = "|".join([kw for kw in re_reserved if kw.isalnum()]) + re_operators = "|".join([re.escape(op) for op in re_reserved if not op.isalnum()]) + + # Python 3.0 uses ID_Start, >3.0 uses XID_Start + if self.version == (3, 0): + id_xid = "" + else: + id_xid = "X" + + # To speed things up on CPython, we use the re module to generate a DFA + # from our token set and execute it in C. Every result yielded by + # iterating this regular expression has exactly one non-empty group + # that would correspond to a e.g. lex scanner branch. + # The only thing left to Python code is then to select one from this + # small set of groups, which is much faster than dissecting the strings. + # + # A lexer has to obey longest-match rule, but a regular expression does not. + # Therefore, the cases in it are carefully sorted so that the longest + # ones come up first. The exception is the identifier case, which would + # otherwise grab all keywords; it is made to work by making it impossible + # for the keyword case to match a word prefix, and ordering it before + # the identifier case. + self._lex_token_re = re.compile(r""" + [ \t\f]* # initial whitespace + ( # 1 + (\\)? # ?2 line continuation + ([\n]|[\r][\n]|[\r]) # 3 newline + | (\#.*) # 4 comment + | ( # 5 floating point or complex literal + (?: [0-9]* \. [0-9]+ + | [0-9]+ \.? + ) [eE] [+-]? [0-9]+ + | [0-9]* \. [0-9]+ + | [0-9]+ \. + ) ([jJ])? # ?6 complex suffix + | ([0-9]+) [jJ] # 7 complex literal + | (?: # integer literal + ( [1-9] [0-9]* ) # 8 dec + | 0[oO] ( [0-7]+ ) # 9 oct + | 0[xX] ( [0-9A-Fa-f]+ ) # 10 hex + | 0[bB] ( [01]+ ) # 11 bin + | ( [0-9] [0-9]* ) # 12 bare oct + ) + ([Ll])? # 13 long option + | ([BbUu]?[Rr]?) # ?14 string literal options + (?: # string literal start + # 15, 16, 17 long string + (""\"|''') ((?: \\?[\n] | \\. | . )*?) (\15) + # 18, 19, 20 short string + | (" |' ) ((?: \\ [\n] | \\. | . )*?) (\18) + # 21 unterminated + | (""\"|'''|"|') + ) + | ((?:{keywords})\b|{operators}) # 22 keywords and operators + | ([A-Za-z_][A-Za-z0-9_]*\b) # 23 identifier + | (\p{{{id_xid}ID_Start}}\p{{{id_xid}ID_Continue}}*) # 24 Unicode identifier + | ($) # 25 end-of-file + ) + """.format(keywords=re_keywords, operators=re_operators, + id_xid=id_xid), re.VERBOSE|re.UNICODE) + + # These are identical for all lexer instances. 
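+    # _lex_escape_re below handles the escape forms valid in byte strings
+    # (single-character, octal and hex escapes); _lex_escape_unicode_re
+    # extends it with the \u, \U and \N{...} forms used in unicode strings.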
+ _lex_escape_pattern = r""" + \\(?: + ([\n\\'"abfnrtv]) # 1 single-char + | ([0-7]{1,3}) # 2 oct + | x([0-9A-Fa-f]{2}) # 3 hex + ) + """ + _lex_escape_re = re.compile(_lex_escape_pattern.encode(), re.VERBOSE) + + _lex_escape_unicode_re = re.compile(_lex_escape_pattern + r""" + | \\(?: + u([0-9A-Fa-f]{4}) # 4 unicode-16 + | U([0-9A-Fa-f]{8}) # 5 unicode-32 + | N\{(.+?)\} # 6 unicode-name + ) + """, re.VERBOSE) + + def next(self, eof_token=False): + """ + Returns token at ``offset`` as a :class:`Token` and advances ``offset`` + to point past the end of the token, where the token has: + + - *range* which is a :class:`pythonparser.source.Range` that includes + the token but not surrounding whitespace, + - *kind* which is a string containing one of Python keywords or operators, + ``newline``, ``float``, ``int``, ``complex``, ``strbegin``, + ``strdata``, ``strend``, ``ident``, ``indent``, ``dedent`` or ``eof`` + (if ``eof_token`` is True). + - *value* which is the flags as lowercase string if *kind* is ``strbegin``, + the string contents if *kind* is ``strdata``, + the numeric value if *kind* is ``float``, ``int`` or ``complex``, + the identifier if *kind* is ``ident`` and ``None`` in any other case. + + :param eof_token: if true, will return a token with kind ``eof`` + when the input is exhausted; if false, will raise ``StopIteration``. + """ + if len(self.queue) == 0: + self._refill(eof_token) + + return self.queue.pop(0) + + def peek(self, eof_token=False): + """Same as :meth:`next`, except the token is not dequeued.""" + if len(self.queue) == 0: + self._refill(eof_token) + + return self.queue[-1] + + # We need separate next and _refill because lexing can sometimes + # generate several tokens, e.g. INDENT + def _refill(self, eof_token): + if self.offset == len(self.source_buffer.source): + range = source.Range(self.source_buffer, self.offset, self.offset) + + if not self.new_line: + self.new_line = True + self.queue.append(Token(range, "newline")) + return + + for i in self.indent[1:]: + self.indent.pop(-1) + self.queue.append(Token(range, "dedent")) + + if eof_token: + self.queue.append(Token(range, "eof")) + elif len(self.queue) == 0: + raise StopIteration + + return + + match = self._lex_token_re.match(self.source_buffer.source, self.offset) + if match is None: + diag = diagnostic.Diagnostic( + "fatal", "unexpected {character}", + {"character": repr(self.source_buffer.source[self.offset]).lstrip("u")}, + source.Range(self.source_buffer, self.offset, self.offset + 1)) + self.diagnostic_engine.process(diag) + + # Should we emit indent/dedent? 
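+        # self.indent is a stack of (level, range, whitespace) entries, seeded
+        # with level 0. A deeper indentation level pushes a new entry and
+        # queues an "indent" token; a shallower level pops entries, queueing
+        # one "dedent" token per pop, and must land exactly on a level already
+        # on the stack, otherwise an "inconsistent indentation" diagnostic is
+        # emitted.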
+ if self.new_line and \ + match.group(3) is None and \ + match.group(4) is None: # not a blank line + whitespace = match.string[match.start(0):match.start(1)] + level = len(whitespace.expandtabs()) + range = source.Range(self.source_buffer, match.start(1), match.start(1)) + if level > self.indent[-1][0]: + self.indent.append((level, range, whitespace)) + self.queue.append(Token(range, "indent")) + elif level < self.indent[-1][0]: + exact = False + while level <= self.indent[-1][0]: + if level == self.indent[-1][0] or self.indent[-1][0] == 0: + exact = True + break + self.indent.pop(-1) + self.queue.append(Token(range, "dedent")) + if not exact: + note = diagnostic.Diagnostic( + "note", "expected to match level here", {}, + self.indent[-1][1]) + error = diagnostic.Diagnostic( + "fatal", "inconsistent indentation", {}, + range, notes=[note]) + self.diagnostic_engine.process(error) + elif whitespace != self.indent[-1][2] and self.version >= (3, 0): + error = diagnostic.Diagnostic( + "error", "inconsistent use of tabs and spaces in indentation", {}, + range) + self.diagnostic_engine.process(error) + + # Prepare for next token. + self.offset = match.end(0) + + tok_range = source.Range(self.source_buffer, *match.span(1)) + if match.group(3) is not None: # newline + if len(self.parentheses) + len(self.square_braces) + len(self.curly_braces) > 0: + # 2.1.6 Implicit line joining + return self._refill(eof_token) + if match.group(2) is not None: + # 2.1.5. Explicit line joining + return self._refill(eof_token) + if self.new_line and not \ + (self.interactive and match.group(0) == match.group(3)): # REPL terminator + # 2.1.7. Blank lines + return self._refill(eof_token) + + self.new_line = True + self.queue.append(Token(tok_range, "newline")) + return + + if match.group(4) is not None: # comment + self.comments.append(source.Comment(tok_range, match.group(4))) + return self._refill(eof_token) + + # Lexing non-whitespace now. 
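+        # The match now falls into one of the remaining alternatives: groups
+        # 5-13 cover numeric literals, 14-21 string literals, 22 keywords and
+        # operators, 23-24 identifiers, and 25 end-of-file.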
+ self.new_line = False + + if sys.version_info > (3,) or not match.group(13): + int_type = int + else: + int_type = long + + if match.group(5) is not None: # floating point or complex literal + if match.group(6) is None: + self.queue.append(Token(tok_range, "float", float(match.group(5)))) + else: + self.queue.append(Token(tok_range, "complex", float(match.group(5)) * 1j)) + + elif match.group(7) is not None: # complex literal + self.queue.append(Token(tok_range, "complex", int(match.group(7)) * 1j)) + + elif match.group(8) is not None: # integer literal, dec + literal = match.group(8) + self._check_long_literal(tok_range, match.group(1)) + self.queue.append(Token(tok_range, "int", int_type(literal))) + + elif match.group(9) is not None: # integer literal, oct + literal = match.group(9) + self._check_long_literal(tok_range, match.group(1)) + self.queue.append(Token(tok_range, "int", int_type(literal, 8))) + + elif match.group(10) is not None: # integer literal, hex + literal = match.group(10) + self._check_long_literal(tok_range, match.group(1)) + self.queue.append(Token(tok_range, "int", int_type(literal, 16))) + + elif match.group(11) is not None: # integer literal, bin + literal = match.group(11) + self._check_long_literal(tok_range, match.group(1)) + self.queue.append(Token(tok_range, "int", int_type(literal, 2))) + + elif match.group(12) is not None: # integer literal, bare oct + literal = match.group(12) + if len(literal) > 1 and self.version >= (3, 0): + error = diagnostic.Diagnostic( + "error", "in Python 3, decimal literals must not start with a zero", {}, + source.Range(self.source_buffer, tok_range.begin_pos, tok_range.begin_pos + 1)) + self.diagnostic_engine.process(error) + self.queue.append(Token(tok_range, "int", int(literal, 8))) + + elif match.group(15) is not None: # long string literal + self._string_literal( + options=match.group(14), begin_span=(match.start(14), match.end(15)), + data=match.group(16), data_span=match.span(16), + end_span=match.span(17)) + + elif match.group(18) is not None: # short string literal + self._string_literal( + options=match.group(14), begin_span=(match.start(14), match.end(18)), + data=match.group(19), data_span=match.span(19), + end_span=match.span(20)) + + elif match.group(21) is not None: # unterminated string + error = diagnostic.Diagnostic( + "fatal", "unterminated string", {}, + tok_range) + self.diagnostic_engine.process(error) + + elif match.group(22) is not None: # keywords and operators + kwop = match.group(22) + self._match_pair_delim(tok_range, kwop) + if kwop == "print" and self.print_function: + self.queue.append(Token(tok_range, "ident", "print")) + else: + self.queue.append(Token(tok_range, kwop)) + + elif match.group(23) is not None: # identifier + self.queue.append(Token(tok_range, "ident", match.group(23))) + + elif match.group(24) is not None: # Unicode identifier + if self.version < (3, 0): + error = diagnostic.Diagnostic( + "error", "in Python 2, Unicode identifiers are not allowed", {}, + tok_range) + self.diagnostic_engine.process(error) + self.queue.append(Token(tok_range, "ident", match.group(24))) + + elif match.group(25) is not None: # end-of-file + # Reuse the EOF logic + return self._refill(eof_token) + + else: + assert False + + def _string_literal(self, options, begin_span, data, data_span, end_span): + options = options.lower() + begin_range = source.Range(self.source_buffer, *begin_span) + data_range = source.Range(self.source_buffer, *data_span) + + if options not in self._string_prefixes[self.version]: + 
error = diagnostic.Diagnostic( + "error", "string prefix '{prefix}' is not available in Python {major}.{minor}", + {"prefix": options, "major": self.version[0], "minor": self.version[1]}, + begin_range) + self.diagnostic_engine.process(error) + + self.queue.append(Token(begin_range, "strbegin", options)) + self.queue.append(Token(data_range, + "strdata", self._replace_escape(data_range, options, data))) + self.queue.append(Token(source.Range(self.source_buffer, *end_span), + "strend")) + + def _replace_escape(self, range, mode, value): + is_raw = ("r" in mode) + is_unicode = "u" in mode or ("b" not in mode and self.unicode_literals) + + if not is_unicode: + value = value.encode(self.source_buffer.encoding) + if is_raw: + return value + return self._replace_escape_bytes(value) + + if is_raw: + return value + + return self._replace_escape_unicode(range, value) + + def _replace_escape_unicode(self, range, value): + chunks = [] + offset = 0 + while offset < len(value): + match = self._lex_escape_unicode_re.search(value, offset) + if match is None: + # Append the remaining of the string + chunks.append(value[offset:]) + break + + # Append the part of string before match + chunks.append(value[offset:match.start()]) + offset = match.end() + + # Process the escape + if match.group(1) is not None: # single-char + chr = match.group(1) + if chr == "\n": + pass + elif chr == "\\" or chr == "'" or chr == "\"": + chunks.append(chr) + elif chr == "a": + chunks.append("\a") + elif chr == "b": + chunks.append("\b") + elif chr == "f": + chunks.append("\f") + elif chr == "n": + chunks.append("\n") + elif chr == "r": + chunks.append("\r") + elif chr == "t": + chunks.append("\t") + elif chr == "v": + chunks.append("\v") + elif match.group(2) is not None: # oct + chunks.append(unichr(int(match.group(2), 8))) + elif match.group(3) is not None: # hex + chunks.append(unichr(int(match.group(3), 16))) + elif match.group(4) is not None: # unicode-16 + chunks.append(unichr(int(match.group(4), 16))) + elif match.group(5) is not None: # unicode-32 + try: + chunks.append(unichr(int(match.group(5), 16))) + except ValueError: + error = diagnostic.Diagnostic( + "error", "unicode character out of range", {}, + source.Range(self.source_buffer, + range.begin_pos + match.start(0), + range.begin_pos + match.end(0))) + self.diagnostic_engine.process(error) + elif match.group(6) is not None: # unicode-name + try: + chunks.append(unicodedata.lookup(match.group(6))) + except KeyError: + error = diagnostic.Diagnostic( + "error", "unknown unicode character name", {}, + source.Range(self.source_buffer, + range.begin_pos + match.start(0), + range.begin_pos + match.end(0))) + self.diagnostic_engine.process(error) + + return "".join(chunks) + + def _replace_escape_bytes(self, value): + chunks = [] + offset = 0 + while offset < len(value): + match = self._lex_escape_re.search(value, offset) + if match is None: + # Append the remaining of the string + chunks.append(value[offset:]) + break + + # Append the part of string before match + chunks.append(value[offset:match.start()]) + offset = match.end() + + # Process the escape + if match.group(1) is not None: # single-char + chr = match.group(1) + if chr == b"\n": + pass + elif chr == b"\\" or chr == b"'" or chr == b"\"": + chunks.append(chr) + elif chr == b"a": + chunks.append(b"\a") + elif chr == b"b": + chunks.append(b"\b") + elif chr == b"f": + chunks.append(b"\f") + elif chr == b"n": + chunks.append(b"\n") + elif chr == b"r": + chunks.append(b"\r") + elif chr == b"t": + 
chunks.append(b"\t") + elif chr == b"v": + chunks.append(b"\v") + elif match.group(2) is not None: # oct + chunks.append(byte(int(match.group(2), 8))) + elif match.group(3) is not None: # hex + chunks.append(byte(int(match.group(3), 16))) + + return b"".join(chunks) + + def _check_long_literal(self, range, literal): + if literal[-1] in "lL" and self.version >= (3, 0): + error = diagnostic.Diagnostic( + "error", "in Python 3, long integer literals were removed", {}, + source.Range(self.source_buffer, range.end_pos - 1, range.end_pos)) + self.diagnostic_engine.process(error) + + def _match_pair_delim(self, range, kwop): + if kwop == "(": + self.parentheses.append(range) + elif kwop == "[": + self.square_braces.append(range) + elif kwop == "{": + self.curly_braces.append(range) + elif kwop == ")": + self._check_innermost_pair_delim(range, "(") + self.parentheses.pop() + elif kwop == "]": + self._check_innermost_pair_delim(range, "[") + self.square_braces.pop() + elif kwop == "}": + self._check_innermost_pair_delim(range, "{") + self.curly_braces.pop() + + def _check_innermost_pair_delim(self, range, expected): + ranges = [] + if len(self.parentheses) > 0: + ranges.append(("(", self.parentheses[-1])) + if len(self.square_braces) > 0: + ranges.append(("[", self.square_braces[-1])) + if len(self.curly_braces) > 0: + ranges.append(("{", self.curly_braces[-1])) + + ranges.sort(key=lambda k: k[1].begin_pos) + if any(ranges): + compl_kind, compl_range = ranges[-1] + if compl_kind != expected: + note = diagnostic.Diagnostic( + "note", "'{delimiter}' opened here", + {"delimiter": compl_kind}, + compl_range) + error = diagnostic.Diagnostic( + "fatal", "mismatched '{delimiter}'", + {"delimiter": range.source()}, + range, notes=[note]) + self.diagnostic_engine.process(error) + else: + error = diagnostic.Diagnostic( + "fatal", "mismatched '{delimiter}'", + {"delimiter": range.source()}, + range) + self.diagnostic_engine.process(error) + + def __iter__(self): + return self + + def __next__(self): + return self.next() diff --git a/third_party/pythonparser/parser.py b/third_party/pythonparser/parser.py new file mode 100644 index 00000000..13995c5e --- /dev/null +++ b/third_party/pythonparser/parser.py @@ -0,0 +1,2017 @@ +# encoding:utf-8 + +""" +The :mod:`parser` module concerns itself with parsing Python source. +""" + +from __future__ import absolute_import, division, print_function, unicode_literals +from functools import reduce +from . import source, diagnostic, lexer, ast + +# A few notes about our approach to parsing: +# +# Python uses an LL(1) parser generator. It's a bit weird, because +# the usual reason to choose LL(1) is to make a handwritten parser +# possible, however Python's grammar is formulated in a way that +# is much more easily recognized if you make an FSM rather than +# the usual "if accept(token)..." ladder. So in a way it is +# the worst of both worlds. +# +# We don't use a parser generator because we want to have an unified +# grammar for all Python versions, and also have grammar coverage +# analysis and nice error recovery. To make the grammar compact, +# we use combinators to compose it from predefined fragments, +# such as "sequence" or "alternation" or "Kleene star". This easily +# gives us one token of lookahead in most cases, but e.g. not +# in the following one: +# +# argument: test | test '=' test +# +# There are two issues with this. First, in an alternation, the first +# variant will be tried (and accepted) earlier. 
Second, if we reverse +# them, by the point it is clear ``'='`` will not be accepted, ``test`` +# has already been consumed. +# +# The way we fix this is by reordering rules so that longest match +# comes first, and adding backtracking on alternations (as well as +# plus and star, since those have a hidden alternation inside). +# +# While backtracking can in principle make asymptotical complexity +# worse, it never makes parsing syntactically correct code supralinear +# with Python's LL(1) grammar, and we could not come up with any +# pathological incorrect input as well. + +# Coverage data +_all_rules = [] +_all_stmts = {} + +# Generic LL parsing combinators +class Unmatched: + pass + +unmatched = Unmatched() + +def llrule(loc, expected, cases=1): + if loc is None: + def decorator(rule): + rule.expected = expected + return rule + else: + def decorator(inner_rule): + if cases == 1: + def rule(*args, **kwargs): + result = inner_rule(*args, **kwargs) + if result is not unmatched: + rule.covered[0] = True + return result + else: + rule = inner_rule + + rule.loc, rule.expected, rule.covered = \ + loc, expected, [False] * cases + _all_rules.append(rule) + + return rule + return decorator + +def action(inner_rule, loc=None): + """ + A decorator returning a function that first runs ``inner_rule`` and then, if its + return value is not None, maps that value using ``mapper``. + + If the value being mapped is a tuple, it is expanded into multiple arguments. + + Similar to attaching semantic actions to rules in traditional parser generators. + """ + def decorator(mapper): + @llrule(loc, inner_rule.expected) + def outer_rule(parser): + result = inner_rule(parser) + if result is unmatched: + return result + if isinstance(result, tuple): + return mapper(parser, *result) + else: + return mapper(parser, result) + return outer_rule + return decorator + +def Eps(value=None, loc=None): + """A rule that accepts no tokens (epsilon) and returns ``value``.""" + @llrule(loc, lambda parser: []) + def rule(parser): + return value + return rule + +def Tok(kind, loc=None): + """A rule that accepts a token of kind ``kind`` and returns it, or returns None.""" + @llrule(loc, lambda parser: [kind]) + def rule(parser): + return parser._accept(kind) + return rule + +def Loc(kind, loc=None): + """A rule that accepts a token of kind ``kind`` and returns its location, or returns None.""" + @llrule(loc, lambda parser: [kind]) + def rule(parser): + result = parser._accept(kind) + if result is unmatched: + return result + return result.loc + return rule + +def Rule(name, loc=None): + """A proxy for a rule called ``name`` which may not be yet defined.""" + @llrule(loc, lambda parser: getattr(parser, name).expected(parser)) + def rule(parser): + return getattr(parser, name)() + return rule + +def Expect(inner_rule, loc=None): + """A rule that executes ``inner_rule`` and emits a diagnostic error if it returns None.""" + @llrule(loc, inner_rule.expected) + def rule(parser): + result = inner_rule(parser) + if result is unmatched: + expected = reduce(list.__add__, [rule.expected(parser) for rule in parser._errrules]) + expected = list(sorted(set(expected))) + + if len(expected) > 1: + expected = " or ".join([", ".join(expected[0:-1]), expected[-1]]) + elif len(expected) == 1: + expected = expected[0] + else: + expected = "(impossible)" + + error_tok = parser._tokens[parser._errindex] + error = diagnostic.Diagnostic( + "fatal", "unexpected {actual}: expected {expected}", + {"actual": error_tok.kind, "expected": expected}, + 
error_tok.loc) + parser.diagnostic_engine.process(error) + return result + return rule + +def Seq(first_rule, *rest_of_rules, **kwargs): + """ + A rule that accepts a sequence of tokens satisfying ``rules`` and returns a tuple + containing their return values, or None if the first rule was not satisfied. + """ + @llrule(kwargs.get("loc", None), first_rule.expected) + def rule(parser): + result = first_rule(parser) + if result is unmatched: + return result + + results = [result] + for rule in rest_of_rules: + result = rule(parser) + if result is unmatched: + return result + results.append(result) + return tuple(results) + return rule + +def SeqN(n, *inner_rules, **kwargs): + """ + A rule that accepts a sequence of tokens satisfying ``rules`` and returns + the value returned by rule number ``n``, or None if the first rule was not satisfied. + """ + @action(Seq(*inner_rules), loc=kwargs.get("loc", None)) + def rule(parser, *values): + return values[n] + return rule + +def Alt(*inner_rules, **kwargs): + """ + A rule that expects a sequence of tokens satisfying one of ``rules`` in sequence + (a rule is satisfied when it returns anything but None) and returns the return + value of that rule, or None if no rules were satisfied. + """ + loc = kwargs.get("loc", None) + expected = lambda parser: reduce(list.__add__, map(lambda x: x.expected(parser), inner_rules)) + if loc is not None: + @llrule(loc, expected, cases=len(inner_rules)) + def rule(parser): + data = parser._save() + for idx, inner_rule in enumerate(inner_rules): + result = inner_rule(parser) + if result is unmatched: + parser._restore(data, rule=inner_rule) + else: + rule.covered[idx] = True + return result + return unmatched + else: + @llrule(loc, expected, cases=len(inner_rules)) + def rule(parser): + data = parser._save() + for inner_rule in inner_rules: + result = inner_rule(parser) + if result is unmatched: + parser._restore(data, rule=inner_rule) + else: + return result + return unmatched + return rule + +def Opt(inner_rule, loc=None): + """Shorthand for ``Alt(inner_rule, Eps())``""" + return Alt(inner_rule, Eps(), loc=loc) + +def Star(inner_rule, loc=None): + """ + A rule that accepts a sequence of tokens satisfying ``inner_rule`` zero or more times, + and returns the returned values in a :class:`list`. + """ + @llrule(loc, lambda parser: []) + def rule(parser): + results = [] + while True: + data = parser._save() + result = inner_rule(parser) + if result is unmatched: + parser._restore(data, rule=inner_rule) + return results + results.append(result) + return rule + +def Plus(inner_rule, loc=None): + """ + A rule that accepts a sequence of tokens satisfying ``inner_rule`` one or more times, + and returns the returned values in a :class:`list`. 
+ """ + @llrule(loc, inner_rule.expected) + def rule(parser): + result = inner_rule(parser) + if result is unmatched: + return result + + results = [result] + while True: + data = parser._save() + result = inner_rule(parser) + if result is unmatched: + parser._restore(data, rule=inner_rule) + return results + results.append(result) + return rule + +class commalist(list): + __slots__ = ("trailing_comma",) + +def List(inner_rule, separator_tok, trailing, leading=True, loc=None): + if not trailing: + @action(Seq(inner_rule, Star(SeqN(1, Tok(separator_tok), inner_rule))), loc=loc) + def outer_rule(parser, first, rest): + return [first] + rest + return outer_rule + else: + # A rule like this: stmt (';' stmt)* [';'] + # This doesn't yield itself to combinators above, because disambiguating + # another iteration of the Kleene star and the trailing separator + # requires two lookahead tokens (naively). + separator_rule = Tok(separator_tok) + @llrule(loc, inner_rule.expected) + def rule(parser): + results = commalist() + + if leading: + result = inner_rule(parser) + if result is unmatched: + return result + else: + results.append(result) + + while True: + result = separator_rule(parser) + if result is unmatched: + results.trailing_comma = None + return results + + result_1 = inner_rule(parser) + if result_1 is unmatched: + results.trailing_comma = result + return results + else: + results.append(result_1) + return rule + +# Python AST specific parser combinators +def Newline(loc=None): + """A rule that accepts token of kind ``newline`` and returns an empty list.""" + @llrule(loc, lambda parser: ["newline"]) + def rule(parser): + result = parser._accept("newline") + if result is unmatched: + return result + return [] + return rule + +def Oper(klass, *kinds, **kwargs): + """ + A rule that accepts a sequence of tokens of kinds ``kinds`` and returns + an instance of ``klass`` with ``loc`` encompassing the entire sequence + or None if the first token is not of ``kinds[0]``. + """ + @action(Seq(*map(Loc, kinds)), loc=kwargs.get("loc", None)) + def rule(parser, *tokens): + return klass(loc=tokens[0].join(tokens[-1])) + return rule + +def BinOper(expr_rulename, op_rule, node=ast.BinOp, loc=None): + @action(Seq(Rule(expr_rulename), Star(Seq(op_rule, Rule(expr_rulename)))), loc=loc) + def rule(parser, lhs, trailers): + for (op, rhs) in trailers: + lhs = node(left=lhs, op=op, right=rhs, + loc=lhs.loc.join(rhs.loc)) + return lhs + return rule + +def BeginEnd(begin_tok, inner_rule, end_tok, empty=None, loc=None): + @action(Seq(Loc(begin_tok), inner_rule, Loc(end_tok)), loc=loc) + def rule(parser, begin_loc, node, end_loc): + if node is None: + node = empty(parser) + + # Collection nodes don't have loc yet. If a node has loc at this + # point, it means it's an expression passed in parentheses. 
+ if node.loc is None and type(node) in [ + ast.List, ast.ListComp, + ast.Dict, ast.DictComp, + ast.Set, ast.SetComp, + ast.GeneratorExp, + ast.Tuple, ast.Repr, + ast.Call, ast.Subscript, + ast.arguments]: + node.begin_loc, node.end_loc, node.loc = \ + begin_loc, end_loc, begin_loc.join(end_loc) + return node + return rule + +class Parser: + + # Generic LL parsing methods + def __init__(self, lexer, version, diagnostic_engine): + self._init_version(version) + self.diagnostic_engine = diagnostic_engine + + self.lexer = lexer + self._tokens = [] + self._index = -1 + self._errindex = -1 + self._errrules = [] + self._advance() + + def _save(self): + return self._index + + def _restore(self, data, rule): + self._index = data + self._token = self._tokens[self._index] + + if self._index > self._errindex: + # We have advanced since last error + self._errindex = self._index + self._errrules = [rule] + elif self._index == self._errindex: + # We're at the same place as last error + self._errrules.append(rule) + else: + # We've backtracked far and are now just failing the + # whole parse + pass + + def _advance(self): + self._index += 1 + if self._index == len(self._tokens): + self._tokens.append(self.lexer.next(eof_token=True)) + self._token = self._tokens[self._index] + + def _accept(self, expected_kind): + if self._token.kind == expected_kind: + result = self._token + self._advance() + return result + return unmatched + + # Python-specific methods + def _init_version(self, version): + if version in ((2, 6), (2, 7)): + if version == (2, 6): + self.with_stmt = self.with_stmt__26 + self.atom_6 = self.atom_6__26 + else: + self.with_stmt = self.with_stmt__27 + self.atom_6 = self.atom_6__27 + self.except_clause_1 = self.except_clause_1__26 + self.classdef = self.classdef__26 + self.subscript = self.subscript__26 + self.raise_stmt = self.raise_stmt__26 + self.comp_if = self.comp_if__26 + self.atom = self.atom__26 + self.funcdef = self.funcdef__26 + self.parameters = self.parameters__26 + self.varargslist = self.varargslist__26 + self.comparison_1 = self.comparison_1__26 + self.exprlist_1 = self.exprlist_1__26 + self.testlist_comp_1 = self.testlist_comp_1__26 + self.expr_stmt_1 = self.expr_stmt_1__26 + self.yield_expr = self.yield_expr__26 + return + elif version in ((3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5)): + if version == (3, 0): + self.with_stmt = self.with_stmt__26 # lol + else: + self.with_stmt = self.with_stmt__27 + self.except_clause_1 = self.except_clause_1__30 + self.classdef = self.classdef__30 + self.subscript = self.subscript__30 + self.raise_stmt = self.raise_stmt__30 + self.comp_if = self.comp_if__30 + self.atom = self.atom__30 + self.funcdef = self.funcdef__30 + self.parameters = self.parameters__30 + if version < (3, 2): + self.varargslist = self.varargslist__30 + self.typedargslist = self.typedargslist__30 + self.comparison_1 = self.comparison_1__30 + self.star_expr = self.star_expr__30 + self.exprlist_1 = self.exprlist_1__30 + self.testlist_comp_1 = self.testlist_comp_1__26 + self.expr_stmt_1 = self.expr_stmt_1__26 + else: + self.varargslist = self.varargslist__32 + self.typedargslist = self.typedargslist__32 + self.comparison_1 = self.comparison_1__32 + self.star_expr = self.star_expr__32 + self.exprlist_1 = self.exprlist_1__32 + self.testlist_comp_1 = self.testlist_comp_1__32 + self.expr_stmt_1 = self.expr_stmt_1__32 + if version < (3, 3): + self.yield_expr = self.yield_expr__26 + else: + self.yield_expr = self.yield_expr__33 + return + + raise 
NotImplementedError("pythonparser.parser.Parser cannot parse Python %s" % + str(version)) + + def _arguments(self, args=None, defaults=None, kwonlyargs=None, kw_defaults=None, + vararg=None, kwarg=None, + star_loc=None, dstar_loc=None, begin_loc=None, end_loc=None, + equals_locs=None, kw_equals_locs=None, loc=None): + if args is None: + args = [] + if defaults is None: + defaults = [] + if kwonlyargs is None: + kwonlyargs = [] + if kw_defaults is None: + kw_defaults = [] + if equals_locs is None: + equals_locs = [] + if kw_equals_locs is None: + kw_equals_locs = [] + return ast.arguments(args=args, defaults=defaults, + kwonlyargs=kwonlyargs, kw_defaults=kw_defaults, + vararg=vararg, kwarg=kwarg, + star_loc=star_loc, dstar_loc=dstar_loc, + begin_loc=begin_loc, end_loc=end_loc, + equals_locs=equals_locs, kw_equals_locs=kw_equals_locs, + loc=loc) + + def _arg(self, tok, colon_loc=None, annotation=None): + loc = tok.loc + if annotation: + loc = loc.join(annotation.loc) + return ast.arg(arg=tok.value, annotation=annotation, + arg_loc=tok.loc, colon_loc=colon_loc, loc=loc) + + def _empty_arglist(self): + return ast.Call(args=[], keywords=[], starargs=None, kwargs=None, + star_loc=None, dstar_loc=None, loc=None) + + def _wrap_tuple(self, elts): + assert len(elts) > 0 + if len(elts) > 1: + return ast.Tuple(ctx=None, elts=elts, + loc=elts[0].loc.join(elts[-1].loc), begin_loc=None, end_loc=None) + else: + return elts[0] + + def _assignable(self, node, is_delete=False): + if isinstance(node, ast.Name) or isinstance(node, ast.Subscript) or \ + isinstance(node, ast.Attribute) or isinstance(node, ast.Starred): + return node + elif (isinstance(node, ast.List) or isinstance(node, ast.Tuple)) and \ + any(node.elts): + node.elts = [self._assignable(elt, is_delete) for elt in node.elts] + return node + else: + if is_delete: + error = diagnostic.Diagnostic( + "fatal", "cannot delete this expression", {}, node.loc) + else: + error = diagnostic.Diagnostic( + "fatal", "cannot assign to this expression", {}, node.loc) + self.diagnostic_engine.process(error) + + def add_flags(self, flags): + if "print_function" in flags: + self.lexer.print_function = True + if "unicode_literals" in flags: + self.lexer.unicode_literals = True + + # Grammar + @action(Expect(Alt(Newline(), + Rule("simple_stmt"), + SeqN(0, Rule("compound_stmt"), Newline())))) + def single_input(self, body): + """single_input: NEWLINE | simple_stmt | compound_stmt NEWLINE""" + loc = None + if body != []: + loc = body[0].loc + return ast.Interactive(body=body, loc=loc) + + @action(Expect(SeqN(0, Star(Alt(Newline(), Rule("stmt"))), Tok("eof")))) + def file_input(parser, body): + """file_input: (NEWLINE | stmt)* ENDMARKER""" + body = reduce(list.__add__, body, []) + loc = None + if body != []: + loc = body[0].loc + return ast.Module(body=body, loc=loc) + + @action(Expect(SeqN(0, Rule("testlist"), Star(Tok("newline")), Tok("eof")))) + def eval_input(self, expr): + """eval_input: testlist NEWLINE* ENDMARKER""" + return ast.Expression(body=[expr], loc=expr.loc) + + @action(Seq(Loc("@"), List(Tok("ident"), ".", trailing=False), + Opt(BeginEnd("(", Opt(Rule("arglist")), ")", + empty=_empty_arglist)), + Loc("newline"))) + def decorator(self, at_loc, idents, call_opt, newline_loc): + """decorator: '@' dotted_name [ '(' [arglist] ')' ] NEWLINE""" + root = idents[0] + dec_loc = root.loc + expr = ast.Name(id=root.value, ctx=None, loc=root.loc) + for ident in idents[1:]: + dot_loc = ident.loc.begin() + dot_loc.begin_pos -= 1 + dec_loc = dec_loc.join(ident.loc) + 
expr = ast.Attribute(value=expr, attr=ident.value, ctx=None, + loc=expr.loc.join(ident.loc), + attr_loc=ident.loc, dot_loc=dot_loc) + + if call_opt: + call_opt.func = expr + call_opt.loc = dec_loc.join(call_opt.loc) + expr = call_opt + return at_loc, expr + + decorators = Plus(Rule("decorator")) + """decorators: decorator+""" + + @action(Seq(Rule("decorators"), Alt(Rule("classdef"), Rule("funcdef")))) + def decorated(self, decorators, classfuncdef): + """decorated: decorators (classdef | funcdef)""" + classfuncdef.at_locs = list(map(lambda x: x[0], decorators)) + classfuncdef.decorator_list = list(map(lambda x: x[1], decorators)) + classfuncdef.loc = classfuncdef.loc.join(decorators[0][0]) + return classfuncdef + + @action(Seq(Loc("def"), Tok("ident"), Rule("parameters"), Loc(":"), Rule("suite"))) + def funcdef__26(self, def_loc, ident_tok, args, colon_loc, suite): + """(2.6, 2.7) funcdef: 'def' NAME parameters ':' suite""" + return ast.FunctionDef(name=ident_tok.value, args=args, returns=None, + body=suite, decorator_list=[], + at_locs=[], keyword_loc=def_loc, name_loc=ident_tok.loc, + colon_loc=colon_loc, arrow_loc=None, + loc=def_loc.join(suite[-1].loc)) + + @action(Seq(Loc("def"), Tok("ident"), Rule("parameters"), + Opt(Seq(Loc("->"), Rule("test"))), + Loc(":"), Rule("suite"))) + def funcdef__30(self, def_loc, ident_tok, args, returns_opt, colon_loc, suite): + """(3.0-) funcdef: 'def' NAME parameters ['->' test] ':' suite""" + arrow_loc = returns = None + if returns_opt: + arrow_loc, returns = returns_opt + return ast.FunctionDef(name=ident_tok.value, args=args, returns=returns, + body=suite, decorator_list=[], + at_locs=[], keyword_loc=def_loc, name_loc=ident_tok.loc, + colon_loc=colon_loc, arrow_loc=arrow_loc, + loc=def_loc.join(suite[-1].loc)) + + parameters__26 = BeginEnd("(", Opt(Rule("varargslist")), ")", empty=_arguments) + """(2.6, 2.7) parameters: '(' [varargslist] ')'""" + + parameters__30 = BeginEnd("(", Opt(Rule("typedargslist")), ")", empty=_arguments) + """(3.0) parameters: '(' [typedargslist] ')'""" + + varargslist__26_1 = Seq(Rule("fpdef"), Opt(Seq(Loc("="), Rule("test")))) + + @action(Seq(Loc("**"), Tok("ident"))) + def varargslist__26_2(self, dstar_loc, kwarg_tok): + return self._arguments(kwarg=self._arg(kwarg_tok), + dstar_loc=dstar_loc, loc=dstar_loc.join(kwarg_tok.loc)) + + @action(Seq(Loc("*"), Tok("ident"), + Opt(Seq(Tok(","), Loc("**"), Tok("ident"))))) + def varargslist__26_3(self, star_loc, vararg_tok, kwarg_opt): + dstar_loc = kwarg = None + loc = star_loc.join(vararg_tok.loc) + vararg = self._arg(vararg_tok) + if kwarg_opt: + _, dstar_loc, kwarg_tok = kwarg_opt + kwarg = self._arg(kwarg_tok) + loc = star_loc.join(kwarg_tok.loc) + return self._arguments(vararg=vararg, kwarg=kwarg, + star_loc=star_loc, dstar_loc=dstar_loc, loc=loc) + + @action(Eps(value=())) + def varargslist__26_4(self): + return self._arguments() + + @action(Alt(Seq(Star(SeqN(0, varargslist__26_1, Tok(","))), + Alt(varargslist__26_2, varargslist__26_3)), + Seq(List(varargslist__26_1, ",", trailing=True), + varargslist__26_4))) + def varargslist__26(self, fparams, args): + """ + (2.6, 2.7) + varargslist: ((fpdef ['=' test] ',')* + ('*' NAME [',' '**' NAME] | '**' NAME) | + fpdef ['=' test] (',' fpdef ['=' test])* [',']) + """ + for fparam, default_opt in fparams: + if default_opt: + equals_loc, default = default_opt + args.equals_locs.append(equals_loc) + args.defaults.append(default) + elif len(args.defaults) > 0: + error = diagnostic.Diagnostic( + "fatal", "non-default argument follows 
default argument", {}, + fparam.loc, [args.args[-1].loc.join(args.defaults[-1].loc)]) + self.diagnostic_engine.process(error) + + args.args.append(fparam) + + def fparam_loc(fparam, default_opt): + if default_opt: + equals_loc, default = default_opt + return fparam.loc.join(default.loc) + else: + return fparam.loc + + if args.loc is None: + args.loc = fparam_loc(*fparams[0]).join(fparam_loc(*fparams[-1])) + elif len(fparams) > 0: + args.loc = args.loc.join(fparam_loc(*fparams[0])) + + return args + + @action(Tok("ident")) + def fpdef_1(self, ident_tok): + return ast.arg(arg=ident_tok.value, annotation=None, + arg_loc=ident_tok.loc, colon_loc=None, + loc=ident_tok.loc) + + fpdef = Alt(fpdef_1, BeginEnd("(", Rule("fplist"), ")", + empty=lambda self: ast.Tuple(elts=[], ctx=None, loc=None))) + """fpdef: NAME | '(' fplist ')'""" + + def _argslist(fpdef_rule, old_style=False): + argslist_1 = Seq(fpdef_rule, Opt(Seq(Loc("="), Rule("test")))) + + @action(Seq(Loc("**"), Tok("ident"))) + def argslist_2(self, dstar_loc, kwarg_tok): + return self._arguments(kwarg=self._arg(kwarg_tok), + dstar_loc=dstar_loc, loc=dstar_loc.join(kwarg_tok.loc)) + + @action(Seq(Loc("*"), Tok("ident"), + Star(SeqN(1, Tok(","), argslist_1)), + Opt(Seq(Tok(","), Loc("**"), Tok("ident"))))) + def argslist_3(self, star_loc, vararg_tok, fparams, kwarg_opt): + dstar_loc = kwarg = None + loc = star_loc.join(vararg_tok.loc) + vararg = self._arg(vararg_tok) + if kwarg_opt: + _, dstar_loc, kwarg_tok = kwarg_opt + kwarg = self._arg(kwarg_tok) + loc = star_loc.join(kwarg_tok.loc) + kwonlyargs, kw_defaults, kw_equals_locs = [], [], [] + for fparam, default_opt in fparams: + if default_opt: + equals_loc, default = default_opt + kw_equals_locs.append(equals_loc) + kw_defaults.append(default) + else: + kw_defaults.append(None) + kwonlyargs.append(fparam) + if any(kw_defaults): + loc = loc.join(kw_defaults[-1].loc) + elif any(kwonlyargs): + loc = loc.join(kwonlyargs[-1].loc) + return self._arguments(vararg=vararg, kwarg=kwarg, + kwonlyargs=kwonlyargs, kw_defaults=kw_defaults, + star_loc=star_loc, dstar_loc=dstar_loc, + kw_equals_locs=kw_equals_locs, loc=loc) + + argslist_4 = Alt(argslist_2, argslist_3) + + @action(Eps(value=())) + def argslist_5(self): + return self._arguments() + + if old_style: + argslist = Alt(Seq(Star(SeqN(0, argslist_1, Tok(","))), + argslist_4), + Seq(List(argslist_1, ",", trailing=True), + argslist_5)) + else: + argslist = Alt(Seq(Eps(value=[]), argslist_4), + Seq(List(argslist_1, ",", trailing=False), + Alt(SeqN(1, Tok(","), Alt(argslist_4, argslist_5)), + argslist_5))) + + def argslist_action(self, fparams, args): + for fparam, default_opt in fparams: + if default_opt: + equals_loc, default = default_opt + args.equals_locs.append(equals_loc) + args.defaults.append(default) + elif len(args.defaults) > 0: + error = diagnostic.Diagnostic( + "fatal", "non-default argument follows default argument", {}, + fparam.loc, [args.args[-1].loc.join(args.defaults[-1].loc)]) + self.diagnostic_engine.process(error) + + args.args.append(fparam) + + def fparam_loc(fparam, default_opt): + if default_opt: + equals_loc, default = default_opt + return fparam.loc.join(default.loc) + else: + return fparam.loc + + if args.loc is None: + args.loc = fparam_loc(*fparams[0]).join(fparam_loc(*fparams[-1])) + elif len(fparams) > 0: + args.loc = args.loc.join(fparam_loc(*fparams[0])) + + return args + + return action(argslist)(argslist_action) + + typedargslist__30 = _argslist(Rule("tfpdef"), old_style=True) + """ + (3.0, 3.1) + typedargslist: 
((tfpdef ['=' test] ',')* + ('*' [tfpdef] (',' tfpdef ['=' test])* [',' '**' tfpdef] | '**' tfpdef) + | tfpdef ['=' test] (',' tfpdef ['=' test])* [',']) + """ + + typedargslist__32 = _argslist(Rule("tfpdef")) + """ + (3.2-) + typedargslist: (tfpdef ['=' test] (',' tfpdef ['=' test])* [',' + ['*' [tfpdef] (',' tfpdef ['=' test])* [',' '**' tfpdef] | '**' tfpdef]] + | '*' [tfpdef] (',' tfpdef ['=' test])* [',' '**' tfpdef] | '**' tfpdef) + """ + + varargslist__30 = _argslist(Rule("vfpdef"), old_style=True) + """ + (3.0, 3.1) + varargslist: ((vfpdef ['=' test] ',')* + ('*' [vfpdef] (',' vfpdef ['=' test])* [',' '**' vfpdef] | '**' vfpdef) + | vfpdef ['=' test] (',' vfpdef ['=' test])* [',']) + """ + + varargslist__32 = _argslist(Rule("vfpdef")) + """ + (3.2-) + varargslist: (vfpdef ['=' test] (',' vfpdef ['=' test])* [',' + ['*' [vfpdef] (',' vfpdef ['=' test])* [',' '**' vfpdef] | '**' vfpdef]] + | '*' [vfpdef] (',' vfpdef ['=' test])* [',' '**' vfpdef] | '**' vfpdef) + """ + + @action(Seq(Tok("ident"), Opt(Seq(Loc(":"), Rule("test"))))) + def tfpdef(self, ident_tok, annotation_opt): + """(3.0-) tfpdef: NAME [':' test]""" + if annotation_opt: + colon_loc, annotation = annotation_opt + return self._arg(ident_tok, colon_loc, annotation) + return self._arg(ident_tok) + + vfpdef = fpdef_1 + """(3.0-) vfpdef: NAME""" + + @action(List(Rule("fpdef"), ",", trailing=True)) + def fplist(self, elts): + """fplist: fpdef (',' fpdef)* [',']""" + return ast.Tuple(elts=elts, ctx=None, loc=None) + + stmt = Alt(Rule("simple_stmt"), Rule("compound_stmt")) + """stmt: simple_stmt | compound_stmt""" + + simple_stmt = SeqN(0, List(Rule("small_stmt"), ";", trailing=True), Tok("newline")) + """simple_stmt: small_stmt (';' small_stmt)* [';'] NEWLINE""" + + small_stmt = Alt(Rule("expr_stmt"), Rule("print_stmt"), Rule("del_stmt"), + Rule("pass_stmt"), Rule("flow_stmt"), Rule("import_stmt"), + Rule("global_stmt"), Rule("nonlocal_stmt"), Rule("exec_stmt"), + Rule("assert_stmt")) + """ + (2.6, 2.7) + small_stmt: (expr_stmt | print_stmt | del_stmt | pass_stmt | flow_stmt | + import_stmt | global_stmt | exec_stmt | assert_stmt) + (3.0-) + small_stmt: (expr_stmt | del_stmt | pass_stmt | flow_stmt | + import_stmt | global_stmt | nonlocal_stmt | assert_stmt) + """ + + expr_stmt_1__26 = Rule("testlist") + expr_stmt_1__32 = Rule("testlist_star_expr") + + @action(Seq(Rule("augassign"), Alt(Rule("yield_expr"), Rule("testlist")))) + def expr_stmt_2(self, augassign, rhs_expr): + return ast.AugAssign(op=augassign, value=rhs_expr) + + @action(Star(Seq(Loc("="), Alt(Rule("yield_expr"), Rule("expr_stmt_1"))))) + def expr_stmt_3(self, seq): + if len(seq) > 0: + return ast.Assign(targets=list(map(lambda x: x[1], seq[:-1])), value=seq[-1][1], + op_locs=list(map(lambda x: x[0], seq))) + else: + return None + + @action(Seq(Rule("expr_stmt_1"), Alt(expr_stmt_2, expr_stmt_3))) + def expr_stmt(self, lhs, rhs): + """ + (2.6, 2.7, 3.0, 3.1) + expr_stmt: testlist (augassign (yield_expr|testlist) | + ('=' (yield_expr|testlist))*) + (3.2-) + expr_stmt: testlist_star_expr (augassign (yield_expr|testlist) | + ('=' (yield_expr|testlist_star_expr))*) + """ + if isinstance(rhs, ast.AugAssign): + if isinstance(lhs, ast.Tuple) or isinstance(lhs, ast.List): + error = diagnostic.Diagnostic( + "fatal", "illegal expression for augmented assignment", {}, + rhs.op.loc, [lhs.loc]) + self.diagnostic_engine.process(error) + else: + rhs.target = self._assignable(lhs) + rhs.loc = rhs.target.loc.join(rhs.value.loc) + return rhs + elif rhs is not None: + rhs.targets = 
list(map(self._assignable, [lhs] + rhs.targets)) + rhs.loc = lhs.loc.join(rhs.value.loc) + return rhs + else: + return ast.Expr(value=lhs, loc=lhs.loc) + + testlist_star_expr = action( + List(Alt(Rule("test"), Rule("star_expr")), ",", trailing=True)) \ + (_wrap_tuple) + """(3.2-) testlist_star_expr: (test|star_expr) (',' (test|star_expr))* [',']""" + + augassign = Alt(Oper(ast.Add, "+="), Oper(ast.Sub, "-="), Oper(ast.MatMult, "@="), + Oper(ast.Mult, "*="), Oper(ast.Div, "/="), Oper(ast.Mod, "%="), + Oper(ast.BitAnd, "&="), Oper(ast.BitOr, "|="), Oper(ast.BitXor, "^="), + Oper(ast.LShift, "<<="), Oper(ast.RShift, ">>="), + Oper(ast.Pow, "**="), Oper(ast.FloorDiv, "//=")) + """augassign: ('+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^=' | + '<<=' | '>>=' | '**=' | '//=')""" + + @action(List(Rule("test"), ",", trailing=True)) + def print_stmt_1(self, values): + nl, loc = True, values[-1].loc + if values.trailing_comma: + nl, loc = False, values.trailing_comma.loc + return ast.Print(dest=None, values=values, nl=nl, + dest_loc=None, loc=loc) + + @action(Seq(Loc(">>"), Rule("test"), Tok(","), List(Rule("test"), ",", trailing=True))) + def print_stmt_2(self, dest_loc, dest, comma_tok, values): + nl, loc = True, values[-1].loc + if values.trailing_comma: + nl, loc = False, values.trailing_comma.loc + return ast.Print(dest=dest, values=values, nl=nl, + dest_loc=dest_loc, loc=loc) + + @action(Eps()) + def print_stmt_3(self, eps): + return ast.Print(dest=None, values=[], nl=True, + dest_loc=None, loc=None) + + @action(Seq(Loc("print"), Alt(print_stmt_1, print_stmt_2, print_stmt_3))) + def print_stmt(self, print_loc, stmt): + """ + (2.6-2.7) + print_stmt: 'print' ( [ test (',' test)* [','] ] | + '>>' test [ (',' test)+ [','] ] ) + """ + stmt.keyword_loc = print_loc + if stmt.loc is None: + stmt.loc = print_loc + else: + stmt.loc = print_loc.join(stmt.loc) + return stmt + + @action(Seq(Loc("del"), List(Rule("expr"), ",", trailing=True))) + def del_stmt(self, stmt_loc, exprs): + # Python uses exprlist here, but does *not* obey the usual + # tuple-wrapping semantics, so we embed the rule directly. 
+ """del_stmt: 'del' exprlist""" + return ast.Delete(targets=[self._assignable(expr, is_delete=True) for expr in exprs], + loc=stmt_loc.join(exprs[-1].loc), keyword_loc=stmt_loc) + + @action(Loc("pass")) + def pass_stmt(self, stmt_loc): + """pass_stmt: 'pass'""" + return ast.Pass(loc=stmt_loc, keyword_loc=stmt_loc) + + flow_stmt = Alt(Rule("break_stmt"), Rule("continue_stmt"), Rule("return_stmt"), + Rule("raise_stmt"), Rule("yield_stmt")) + """flow_stmt: break_stmt | continue_stmt | return_stmt | raise_stmt | yield_stmt""" + + @action(Loc("break")) + def break_stmt(self, stmt_loc): + """break_stmt: 'break'""" + return ast.Break(loc=stmt_loc, keyword_loc=stmt_loc) + + @action(Loc("continue")) + def continue_stmt(self, stmt_loc): + """continue_stmt: 'continue'""" + return ast.Continue(loc=stmt_loc, keyword_loc=stmt_loc) + + @action(Seq(Loc("return"), Opt(Rule("testlist")))) + def return_stmt(self, stmt_loc, values): + """return_stmt: 'return' [testlist]""" + loc = stmt_loc + if values: + loc = loc.join(values.loc) + return ast.Return(value=values, + loc=loc, keyword_loc=stmt_loc) + + @action(Rule("yield_expr")) + def yield_stmt(self, expr): + """yield_stmt: yield_expr""" + return ast.Expr(value=expr, loc=expr.loc) + + @action(Seq(Loc("raise"), Opt(Seq(Rule("test"), + Opt(Seq(Tok(","), Rule("test"), + Opt(SeqN(1, Tok(","), Rule("test"))))))))) + def raise_stmt__26(self, raise_loc, type_opt): + """(2.6, 2.7) raise_stmt: 'raise' [test [',' test [',' test]]]""" + type_ = inst = tback = None + loc = raise_loc + if type_opt: + type_, inst_opt = type_opt + loc = loc.join(type_.loc) + if inst_opt: + _, inst, tback = inst_opt + loc = loc.join(inst.loc) + if tback: + loc = loc.join(tback.loc) + return ast.Raise(exc=type_, inst=inst, tback=tback, cause=None, + keyword_loc=raise_loc, from_loc=None, loc=loc) + + @action(Seq(Loc("raise"), Opt(Seq(Rule("test"), Opt(Seq(Loc("from"), Rule("test"))))))) + def raise_stmt__30(self, raise_loc, exc_opt): + """(3.0-) raise_stmt: 'raise' [test ['from' test]]""" + exc = from_loc = cause = None + loc = raise_loc + if exc_opt: + exc, cause_opt = exc_opt + loc = loc.join(exc.loc) + if cause_opt: + from_loc, cause = cause_opt + loc = loc.join(cause.loc) + return ast.Raise(exc=exc, inst=None, tback=None, cause=cause, + keyword_loc=raise_loc, from_loc=from_loc, loc=loc) + + import_stmt = Alt(Rule("import_name"), Rule("import_from")) + """import_stmt: import_name | import_from""" + + @action(Seq(Loc("import"), Rule("dotted_as_names"))) + def import_name(self, import_loc, names): + """import_name: 'import' dotted_as_names""" + return ast.Import(names=names, + keyword_loc=import_loc, loc=import_loc.join(names[-1].loc)) + + @action(Loc(".")) + def import_from_1(self, loc): + return 1, loc + + @action(Loc("...")) + def import_from_2(self, loc): + return 3, loc + + @action(Seq(Star(Alt(import_from_1, import_from_2)), Rule("dotted_name"))) + def import_from_3(self, dots, dotted_name): + dots_loc, dots_count = None, 0 + if any(dots): + dots_loc = dots[0][1].join(dots[-1][1]) + dots_count = sum([count for count, loc in dots]) + return (dots_loc, dots_count), dotted_name + + @action(Plus(Alt(import_from_1, import_from_2))) + def import_from_4(self, dots): + dots_loc = dots[0][1].join(dots[-1][1]) + dots_count = sum([count for count, loc in dots]) + return (dots_loc, dots_count), None + + @action(Loc("*")) + def import_from_5(self, star_loc): + return (None, 0), \ + [ast.alias(name="*", asname=None, + name_loc=star_loc, as_loc=None, asname_loc=None, loc=star_loc)], \ + None + + 
@action(Rule("atom_5")) + def import_from_7(self, string): + return (None, 0), (string.loc, string.s) + + @action(Rule("import_as_names")) + def import_from_6(self, names): + return (None, 0), names, None + + @action(Seq(Loc("from"), Alt(import_from_3, import_from_4, import_from_7), + Loc("import"), Alt(import_from_5, + Seq(Loc("("), Rule("import_as_names"), Loc(")")), + import_from_6))) + def import_from(self, from_loc, module_name, import_loc, names): + """ + (2.6, 2.7) + import_from: ('from' ('.'* dotted_name | '.'+) + 'import' ('*' | '(' import_as_names ')' | import_as_names)) + (3.0-) + # note below: the ('.' | '...') is necessary because '...' is tokenized as ELLIPSIS + import_from: ('from' (('.' | '...')* dotted_name | ('.' | '...')+) + 'import' ('*' | '(' import_as_names ')' | import_as_names)) + """ + (dots_loc, dots_count), dotted_name_opt = module_name + module_loc = module = None + if dotted_name_opt: + module_loc, module = dotted_name_opt + lparen_loc, names, rparen_loc = names + loc = from_loc.join(names[-1].loc) + if rparen_loc: + loc = loc.join(rparen_loc) + + if module == "__future__": + self.add_flags([x.name for x in names]) + + return ast.ImportFrom(names=names, module=module, level=dots_count, + keyword_loc=from_loc, dots_loc=dots_loc, module_loc=module_loc, + import_loc=import_loc, lparen_loc=lparen_loc, rparen_loc=rparen_loc, + loc=loc) + + @action(Seq(Tok("ident"), Opt(Seq(Loc("as"), Tok("ident"))))) + def import_as_name(self, name_tok, as_name_opt): + """import_as_name: NAME ['as' NAME]""" + asname_name = asname_loc = as_loc = None + loc = name_tok.loc + if as_name_opt: + as_loc, asname = as_name_opt + asname_name = asname.value + asname_loc = asname.loc + loc = loc.join(asname.loc) + return ast.alias(name=name_tok.value, asname=asname_name, + loc=loc, name_loc=name_tok.loc, as_loc=as_loc, asname_loc=asname_loc) + + @action(Seq(Rule("dotted_name"), Opt(Seq(Loc("as"), Tok("ident"))))) + def dotted_as_name(self, dotted_name, as_name_opt): + """dotted_as_name: dotted_name ['as' NAME]""" + asname_name = asname_loc = as_loc = None + dotted_name_loc, dotted_name_name = dotted_name + loc = dotted_name_loc + if as_name_opt: + as_loc, asname = as_name_opt + asname_name = asname.value + asname_loc = asname.loc + loc = loc.join(asname.loc) + return ast.alias(name=dotted_name_name, asname=asname_name, + loc=loc, name_loc=dotted_name_loc, as_loc=as_loc, asname_loc=asname_loc) + + @action(Seq(Rule("atom_5"), Opt(Seq(Loc("as"), Tok("ident"))))) + def str_as_name(self, string, as_name_opt): + asname_name = asname_loc = as_loc = None + loc = string.loc + if as_name_opt: + as_loc, asname = as_name_opt + asname_name = asname.value + asname_loc = asname.loc + loc = loc.join(asname.loc) + return ast.alias(name=string.s, asname=asname_name, + loc=loc, name_loc=string.loc, as_loc=as_loc, asname_loc=asname_loc) + + import_as_names = List(Rule("import_as_name"), ",", trailing=True) + """import_as_names: import_as_name (',' import_as_name)* [',']""" + + dotted_as_names = List(Alt(Rule("dotted_as_name"), Rule("str_as_name")), ",", trailing=False) + """dotted_as_names: dotted_as_name (',' dotted_as_name)*""" + + @action(List(Tok("ident"), ".", trailing=False)) + def dotted_name(self, idents): + """dotted_name: NAME ('.' 
NAME)*""" + return idents[0].loc.join(idents[-1].loc), \ + ".".join(list(map(lambda x: x.value, idents))) + + @action(Seq(Loc("global"), List(Tok("ident"), ",", trailing=False))) + def global_stmt(self, global_loc, names): + """global_stmt: 'global' NAME (',' NAME)*""" + return ast.Global(names=list(map(lambda x: x.value, names)), + name_locs=list(map(lambda x: x.loc, names)), + keyword_loc=global_loc, loc=global_loc.join(names[-1].loc)) + + @action(Seq(Loc("exec"), Rule("expr"), + Opt(Seq(Loc("in"), Rule("test"), + Opt(SeqN(1, Loc(","), Rule("test"))))))) + def exec_stmt(self, exec_loc, body, in_opt): + """(2.6, 2.7) exec_stmt: 'exec' expr ['in' test [',' test]]""" + in_loc, globals, locals = None, None, None + loc = exec_loc.join(body.loc) + if in_opt: + in_loc, globals, locals = in_opt + if locals: + loc = loc.join(locals.loc) + else: + loc = loc.join(globals.loc) + return ast.Exec(body=body, locals=locals, globals=globals, + loc=loc, keyword_loc=exec_loc, in_loc=in_loc) + + @action(Seq(Loc("nonlocal"), List(Tok("ident"), ",", trailing=False))) + def nonlocal_stmt(self, nonlocal_loc, names): + """(3.0-) nonlocal_stmt: 'nonlocal' NAME (',' NAME)*""" + return ast.Nonlocal(names=list(map(lambda x: x.value, names)), + name_locs=list(map(lambda x: x.loc, names)), + keyword_loc=nonlocal_loc, loc=nonlocal_loc.join(names[-1].loc)) + + @action(Seq(Loc("assert"), Rule("test"), Opt(SeqN(1, Tok(","), Rule("test"))))) + def assert_stmt(self, assert_loc, test, msg): + """assert_stmt: 'assert' test [',' test]""" + loc = assert_loc.join(test.loc) + if msg: + loc = loc.join(msg.loc) + return ast.Assert(test=test, msg=msg, + loc=loc, keyword_loc=assert_loc) + + @action(Alt(Rule("if_stmt"), Rule("while_stmt"), Rule("for_stmt"), + Rule("try_stmt"), Rule("with_stmt"), Rule("funcdef"), + Rule("classdef"), Rule("decorated"))) + def compound_stmt(self, stmt): + """compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | + funcdef | classdef | decorated""" + return [stmt] + + @action(Seq(Loc("if"), Rule("test"), Loc(":"), Rule("suite"), + Star(Seq(Loc("elif"), Rule("test"), Loc(":"), Rule("suite"))), + Opt(Seq(Loc("else"), Loc(":"), Rule("suite"))))) + def if_stmt(self, if_loc, test, if_colon_loc, body, elifs, else_opt): + """if_stmt: 'if' test ':' suite ('elif' test ':' suite)* ['else' ':' suite]""" + stmt = ast.If(orelse=[], + else_loc=None, else_colon_loc=None) + + if else_opt: + stmt.else_loc, stmt.else_colon_loc, stmt.orelse = else_opt + + for elif_ in reversed(elifs): + stmt.keyword_loc, stmt.test, stmt.if_colon_loc, stmt.body = elif_ + stmt.loc = stmt.keyword_loc.join(stmt.body[-1].loc) + if stmt.orelse: + stmt.loc = stmt.loc.join(stmt.orelse[-1].loc) + stmt = ast.If(orelse=[stmt], + else_loc=None, else_colon_loc=None) + + stmt.keyword_loc, stmt.test, stmt.if_colon_loc, stmt.body = \ + if_loc, test, if_colon_loc, body + stmt.loc = stmt.keyword_loc.join(stmt.body[-1].loc) + if stmt.orelse: + stmt.loc = stmt.loc.join(stmt.orelse[-1].loc) + return stmt + + @action(Seq(Loc("while"), Rule("test"), Loc(":"), Rule("suite"), + Opt(Seq(Loc("else"), Loc(":"), Rule("suite"))))) + def while_stmt(self, while_loc, test, while_colon_loc, body, else_opt): + """while_stmt: 'while' test ':' suite ['else' ':' suite]""" + stmt = ast.While(test=test, body=body, orelse=[], + keyword_loc=while_loc, while_colon_loc=while_colon_loc, + else_loc=None, else_colon_loc=None, + loc=while_loc.join(body[-1].loc)) + if else_opt: + stmt.else_loc, stmt.else_colon_loc, stmt.orelse = else_opt + stmt.loc = 
stmt.loc.join(stmt.orelse[-1].loc) + + return stmt + + @action(Seq(Loc("for"), Rule("exprlist"), Loc("in"), Rule("testlist"), + Loc(":"), Rule("suite"), + Opt(Seq(Loc("else"), Loc(":"), Rule("suite"))))) + def for_stmt(self, for_loc, target, in_loc, iter, for_colon_loc, body, else_opt): + """for_stmt: 'for' exprlist 'in' testlist ':' suite ['else' ':' suite]""" + stmt = ast.For(target=self._assignable(target), iter=iter, body=body, orelse=[], + keyword_loc=for_loc, in_loc=in_loc, for_colon_loc=for_colon_loc, + else_loc=None, else_colon_loc=None, + loc=for_loc.join(body[-1].loc)) + if else_opt: + stmt.else_loc, stmt.else_colon_loc, stmt.orelse = else_opt + stmt.loc = stmt.loc.join(stmt.orelse[-1].loc) + + return stmt + + @action(Seq(Plus(Seq(Rule("except_clause"), Loc(":"), Rule("suite"))), + Opt(Seq(Loc("else"), Loc(":"), Rule("suite"))), + Opt(Seq(Loc("finally"), Loc(":"), Rule("suite"))))) + def try_stmt_1(self, clauses, else_opt, finally_opt): + handlers = [] + for clause in clauses: + handler, handler.colon_loc, handler.body = clause + handler.loc = handler.loc.join(handler.body[-1].loc) + handlers.append(handler) + + else_loc, else_colon_loc, orelse = None, None, [] + loc = handlers[-1].loc + if else_opt: + else_loc, else_colon_loc, orelse = else_opt + loc = orelse[-1].loc + + finally_loc, finally_colon_loc, finalbody = None, None, [] + if finally_opt: + finally_loc, finally_colon_loc, finalbody = finally_opt + loc = finalbody[-1].loc + stmt = ast.Try(body=None, handlers=handlers, orelse=orelse, finalbody=finalbody, + else_loc=else_loc, else_colon_loc=else_colon_loc, + finally_loc=finally_loc, finally_colon_loc=finally_colon_loc, + loc=loc) + return stmt + + @action(Seq(Loc("finally"), Loc(":"), Rule("suite"))) + def try_stmt_2(self, finally_loc, finally_colon_loc, finalbody): + return ast.Try(body=None, handlers=[], orelse=[], finalbody=finalbody, + else_loc=None, else_colon_loc=None, + finally_loc=finally_loc, finally_colon_loc=finally_colon_loc, + loc=finalbody[-1].loc) + + @action(Seq(Loc("try"), Loc(":"), Rule("suite"), Alt(try_stmt_1, try_stmt_2))) + def try_stmt(self, try_loc, try_colon_loc, body, stmt): + """ + try_stmt: ('try' ':' suite + ((except_clause ':' suite)+ + ['else' ':' suite] + ['finally' ':' suite] | + 'finally' ':' suite)) + """ + stmt.keyword_loc, stmt.try_colon_loc, stmt.body = \ + try_loc, try_colon_loc, body + stmt.loc = stmt.loc.join(try_loc) + return stmt + + @action(Seq(Loc("with"), Rule("test"), Opt(Rule("with_var")), Loc(":"), Rule("suite"))) + def with_stmt__26(self, with_loc, context, with_var, colon_loc, body): + """(2.6, 3.0) with_stmt: 'with' test [ with_var ] ':' suite""" + if with_var: + as_loc, optional_vars = with_var + item = ast.withitem(context_expr=context, optional_vars=optional_vars, + as_loc=as_loc, loc=context.loc.join(optional_vars.loc)) + else: + item = ast.withitem(context_expr=context, optional_vars=None, + as_loc=None, loc=context.loc) + return ast.With(items=[item], body=body, + keyword_loc=with_loc, colon_loc=colon_loc, + loc=with_loc.join(body[-1].loc)) + + with_var = Seq(Loc("as"), Rule("expr")) + """(2.6, 3.0) with_var: 'as' expr""" + + @action(Seq(Loc("with"), List(Rule("with_item"), ",", trailing=False), Loc(":"), + Rule("suite"))) + def with_stmt__27(self, with_loc, items, colon_loc, body): + """(2.7, 3.1-) with_stmt: 'with' with_item (',' with_item)* ':' suite""" + return ast.With(items=items, body=body, + keyword_loc=with_loc, colon_loc=colon_loc, + loc=with_loc.join(body[-1].loc)) + + @action(Seq(Rule("test"), 
Opt(Seq(Loc("as"), Rule("expr"))))) + def with_item(self, context, as_opt): + """(2.7, 3.1-) with_item: test ['as' expr]""" + if as_opt: + as_loc, optional_vars = as_opt + return ast.withitem(context_expr=context, optional_vars=optional_vars, + as_loc=as_loc, loc=context.loc.join(optional_vars.loc)) + else: + return ast.withitem(context_expr=context, optional_vars=None, + as_loc=None, loc=context.loc) + + @action(Seq(Alt(Loc("as"), Loc(",")), Rule("test"))) + def except_clause_1__26(self, as_loc, name): + return as_loc, None, name + + @action(Seq(Loc("as"), Tok("ident"))) + def except_clause_1__30(self, as_loc, name): + return as_loc, name, None + + @action(Seq(Loc("except"), + Opt(Seq(Rule("test"), + Opt(Rule("except_clause_1")))))) + def except_clause(self, except_loc, exc_opt): + """ + (2.6, 2.7) except_clause: 'except' [test [('as' | ',') test]] + (3.0-) except_clause: 'except' [test ['as' NAME]] + """ + type_ = name = as_loc = name_loc = None + loc = except_loc + if exc_opt: + type_, name_opt = exc_opt + loc = loc.join(type_.loc) + if name_opt: + as_loc, name_tok, name_node = name_opt + if name_tok: + name = name_tok.value + name_loc = name_tok.loc + else: + name = name_node + name_loc = name_node.loc + loc = loc.join(name_loc) + return ast.ExceptHandler(type=type_, name=name, + except_loc=except_loc, as_loc=as_loc, name_loc=name_loc, + loc=loc) + + @action(Plus(Rule("stmt"))) + def suite_1(self, stmts): + return reduce(list.__add__, stmts, []) + + suite = Alt(Rule("simple_stmt"), + SeqN(2, Tok("newline"), Tok("indent"), suite_1, Tok("dedent"))) + """suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT""" + + # 2.x-only backwards compatibility start + testlist_safe = action(List(Rule("old_test"), ",", trailing=False))(_wrap_tuple) + """(2.6, 2.7) testlist_safe: old_test [(',' old_test)+ [',']]""" + + old_test = Alt(Rule("or_test"), Rule("old_lambdef")) + """(2.6, 2.7) old_test: or_test | old_lambdef""" + + @action(Seq(Loc("lambda"), Opt(Rule("varargslist")), Loc(":"), Rule("old_test"))) + def old_lambdef(self, lambda_loc, args_opt, colon_loc, body): + """(2.6, 2.7) old_lambdef: 'lambda' [varargslist] ':' old_test""" + if args_opt is None: + args_opt = self._arguments() + args_opt.loc = colon_loc.begin() + return ast.Lambda(args=args_opt, body=body, + lambda_loc=lambda_loc, colon_loc=colon_loc, + loc=lambda_loc.join(body.loc)) + # 2.x-only backwards compatibility end + + @action(Seq(Rule("or_test"), Opt(Seq(Loc("if"), Rule("or_test"), + Loc("else"), Rule("test"))))) + def test_1(self, lhs, rhs_opt): + if rhs_opt is not None: + if_loc, test, else_loc, orelse = rhs_opt + return ast.IfExp(test=test, body=lhs, orelse=orelse, + if_loc=if_loc, else_loc=else_loc, loc=lhs.loc.join(orelse.loc)) + return lhs + + test = Alt(test_1, Rule("lambdef")) + """test: or_test ['if' or_test 'else' test] | lambdef""" + + test_nocond = Alt(Rule("or_test"), Rule("lambdef_nocond")) + """(3.0-) test_nocond: or_test | lambdef_nocond""" + + def lambdef_action(self, lambda_loc, args_opt, colon_loc, body): + if args_opt is None: + args_opt = self._arguments() + args_opt.loc = colon_loc.begin() + return ast.Lambda(args=args_opt, body=body, + lambda_loc=lambda_loc, colon_loc=colon_loc, + loc=lambda_loc.join(body.loc)) + + lambdef = action( + Seq(Loc("lambda"), Opt(Rule("varargslist")), Loc(":"), Rule("test"))) \ + (lambdef_action) + """lambdef: 'lambda' [varargslist] ':' test""" + + lambdef_nocond = action( + Seq(Loc("lambda"), Opt(Rule("varargslist")), Loc(":"), Rule("test_nocond"))) \ + (lambdef_action) + """(3.0-) 
lambdef_nocond: 'lambda' [varargslist] ':' test_nocond""" + + @action(Seq(Rule("and_test"), Star(Seq(Loc("or"), Rule("and_test"))))) + def or_test(self, lhs, rhs): + """or_test: and_test ('or' and_test)*""" + if len(rhs) > 0: + return ast.BoolOp(op=ast.Or(), + values=[lhs] + list(map(lambda x: x[1], rhs)), + loc=lhs.loc.join(rhs[-1][1].loc), + op_locs=list(map(lambda x: x[0], rhs))) + else: + return lhs + + @action(Seq(Rule("not_test"), Star(Seq(Loc("and"), Rule("not_test"))))) + def and_test(self, lhs, rhs): + """and_test: not_test ('and' not_test)*""" + if len(rhs) > 0: + return ast.BoolOp(op=ast.And(), + values=[lhs] + list(map(lambda x: x[1], rhs)), + loc=lhs.loc.join(rhs[-1][1].loc), + op_locs=list(map(lambda x: x[0], rhs))) + else: + return lhs + + @action(Seq(Oper(ast.Not, "not"), Rule("not_test"))) + def not_test_1(self, op, operand): + return ast.UnaryOp(op=op, operand=operand, + loc=op.loc.join(operand.loc)) + + not_test = Alt(not_test_1, Rule("comparison")) + """not_test: 'not' not_test | comparison""" + + comparison_1__26 = Seq(Rule("expr"), Star(Seq(Rule("comp_op"), Rule("expr")))) + comparison_1__30 = Seq(Rule("star_expr"), Star(Seq(Rule("comp_op"), Rule("star_expr")))) + comparison_1__32 = comparison_1__26 + + @action(Rule("comparison_1")) + def comparison(self, lhs, rhs): + """ + (2.6, 2.7) comparison: expr (comp_op expr)* + (3.0, 3.1) comparison: star_expr (comp_op star_expr)* + (3.2-) comparison: expr (comp_op expr)* + """ + if len(rhs) > 0: + return ast.Compare(left=lhs, ops=list(map(lambda x: x[0], rhs)), + comparators=list(map(lambda x: x[1], rhs)), + loc=lhs.loc.join(rhs[-1][1].loc)) + else: + return lhs + + @action(Seq(Opt(Loc("*")), Rule("expr"))) + def star_expr__30(self, star_opt, expr): + """(3.0, 3.1) star_expr: ['*'] expr""" + if star_opt: + return ast.Starred(value=expr, ctx=None, + star_loc=star_opt, loc=expr.loc.join(star_opt)) + return expr + + @action(Seq(Loc("*"), Rule("expr"))) + def star_expr__32(self, star_loc, expr): + """(3.0-) star_expr: '*' expr""" + return ast.Starred(value=expr, ctx=None, + star_loc=star_loc, loc=expr.loc.join(star_loc)) + + comp_op = Alt(Oper(ast.Lt, "<"), Oper(ast.Gt, ">"), Oper(ast.Eq, "=="), + Oper(ast.GtE, ">="), Oper(ast.LtE, "<="), Oper(ast.NotEq, "<>"), + Oper(ast.NotEq, "!="), + Oper(ast.In, "in"), Oper(ast.NotIn, "not", "in"), + Oper(ast.IsNot, "is", "not"), Oper(ast.Is, "is")) + """ + (2.6, 2.7) comp_op: '<'|'>'|'=='|'>='|'<='|'<>'|'!='|'in'|'not' 'in'|'is'|'is' 'not' + (3.0-) comp_op: '<'|'>'|'=='|'>='|'<='|'!='|'in'|'not' 'in'|'is'|'is' 'not' + """ + + expr = BinOper("xor_expr", Oper(ast.BitOr, "|")) + """expr: xor_expr ('|' xor_expr)*""" + + xor_expr = BinOper("and_expr", Oper(ast.BitXor, "^")) + """xor_expr: and_expr ('^' and_expr)*""" + + and_expr = BinOper("shift_expr", Oper(ast.BitAnd, "&")) + """and_expr: shift_expr ('&' shift_expr)*""" + + shift_expr = BinOper("arith_expr", Alt(Oper(ast.LShift, "<<"), Oper(ast.RShift, ">>"))) + """shift_expr: arith_expr (('<<'|'>>') arith_expr)*""" + + arith_expr = BinOper("term", Alt(Oper(ast.Add, "+"), Oper(ast.Sub, "-"))) + """arith_expr: term (('+'|'-') term)*""" + + term = BinOper("factor", Alt(Oper(ast.Mult, "*"), Oper(ast.MatMult, "@"), + Oper(ast.Div, "/"), Oper(ast.Mod, "%"), + Oper(ast.FloorDiv, "//"))) + """term: factor (('*'|'/'|'%'|'//') factor)*""" + + @action(Seq(Alt(Oper(ast.UAdd, "+"), Oper(ast.USub, "-"), Oper(ast.Invert, "~")), + Rule("factor"))) + def factor_1(self, op, factor): + return ast.UnaryOp(op=op, operand=factor, + loc=op.loc.join(factor.loc)) + + 
factor = Alt(factor_1, Rule("power")) + """factor: ('+'|'-'|'~') factor | power""" + + @action(Seq(Rule("atom"), Star(Rule("trailer")), Opt(Seq(Loc("**"), Rule("factor"))))) + def power(self, atom, trailers, factor_opt): + """power: atom trailer* ['**' factor]""" + for trailer in trailers: + if isinstance(trailer, ast.Attribute) or isinstance(trailer, ast.Subscript): + trailer.value = atom + elif isinstance(trailer, ast.Call): + trailer.func = atom + trailer.loc = atom.loc.join(trailer.loc) + atom = trailer + if factor_opt: + op_loc, factor = factor_opt + return ast.BinOp(left=atom, op=ast.Pow(loc=op_loc), right=factor, + loc=atom.loc.join(factor.loc)) + return atom + + @action(Rule("testlist1")) + def atom_1(self, expr): + return ast.Repr(value=expr, loc=None) + + @action(Tok("ident")) + def atom_2(self, tok): + return ast.Name(id=tok.value, loc=tok.loc, ctx=None) + + @action(Alt(Tok("int"), Tok("float"), Tok("complex"))) + def atom_3(self, tok): + return ast.Num(n=tok.value, loc=tok.loc) + + @action(Seq(Tok("strbegin"), Tok("strdata"), Tok("strend"))) + def atom_4(self, begin_tok, data_tok, end_tok): + return ast.Str(s=data_tok.value, + begin_loc=begin_tok.loc, end_loc=end_tok.loc, + loc=begin_tok.loc.join(end_tok.loc)) + + @action(Plus(atom_4)) + def atom_5(self, strings): + joint = "" + if all(isinstance(x.s, bytes) for x in strings): + joint = b"" + return ast.Str(s=joint.join([x.s for x in strings]), + begin_loc=strings[0].begin_loc, end_loc=strings[-1].end_loc, + loc=strings[0].loc.join(strings[-1].loc)) + + atom_6__26 = Rule("dictmaker") + atom_6__27 = Rule("dictorsetmaker") + + atom__26 = Alt(BeginEnd("(", Opt(Alt(Rule("yield_expr"), Rule("testlist_comp"))), ")", + empty=lambda self: ast.Tuple(elts=[], ctx=None, loc=None)), + BeginEnd("[", Opt(Rule("listmaker")), "]", + empty=lambda self: ast.List(elts=[], ctx=None, loc=None)), + BeginEnd("{", Opt(Rule("atom_6")), "}", + empty=lambda self: ast.Dict(keys=[], values=[], colon_locs=[], + loc=None)), + BeginEnd("`", atom_1, "`"), + atom_2, atom_3, atom_5) + """ + (2.6) + atom: ('(' [yield_expr|testlist_gexp] ')' | + '[' [listmaker] ']' | + '{' [dictmaker] '}' | + '`' testlist1 '`' | + NAME | NUMBER | STRING+) + (2.7) + atom: ('(' [yield_expr|testlist_comp] ')' | + '[' [listmaker] ']' | + '{' [dictorsetmaker] '}' | + '`' testlist1 '`' | + NAME | NUMBER | STRING+) + """ + + @action(Loc("...")) + def atom_7(self, loc): + return ast.Ellipsis(loc=loc) + + @action(Alt(Tok("None"), Tok("True"), Tok("False"))) + def atom_8(self, tok): + if tok.kind == "None": + value = None + elif tok.kind == "True": + value = True + elif tok.kind == "False": + value = False + return ast.NameConstant(value=value, loc=tok.loc) + + atom__30 = Alt(BeginEnd("(", Opt(Alt(Rule("yield_expr"), Rule("testlist_comp"))), ")", + empty=lambda self: ast.Tuple(elts=[], ctx=None, loc=None)), + BeginEnd("[", Opt(Rule("testlist_comp__list")), "]", + empty=lambda self: ast.List(elts=[], ctx=None, loc=None)), + BeginEnd("{", Opt(Rule("dictorsetmaker")), "}", + empty=lambda self: ast.Dict(keys=[], values=[], colon_locs=[], + loc=None)), + atom_2, atom_3, atom_5, atom_7, atom_8) + """ + (3.0-) + atom: ('(' [yield_expr|testlist_comp] ')' | + '[' [testlist_comp] ']' | + '{' [dictorsetmaker] '}' | + NAME | NUMBER | STRING+ | '...' 
| 'None' | 'True' | 'False') + """ + + def list_gen_action(self, lhs, rhs): + if rhs is None: # (x) + return lhs + elif isinstance(rhs, ast.Tuple) or isinstance(rhs, ast.List): + rhs.elts = [lhs] + rhs.elts + return rhs + elif isinstance(rhs, ast.ListComp) or isinstance(rhs, ast.GeneratorExp): + rhs.elt = lhs + return rhs + + @action(Rule("list_for")) + def listmaker_1(self, compose): + return ast.ListComp(generators=compose([]), loc=None) + + @action(List(Rule("test"), ",", trailing=True, leading=False)) + def listmaker_2(self, elts): + return ast.List(elts=elts, ctx=None, loc=None) + + listmaker = action( + Seq(Rule("test"), + Alt(listmaker_1, listmaker_2))) \ + (list_gen_action) + """listmaker: test ( list_for | (',' test)* [','] )""" + + testlist_comp_1__26 = Rule("test") + testlist_comp_1__32 = Alt(Rule("test"), Rule("star_expr")) + + @action(Rule("comp_for")) + def testlist_comp_2(self, compose): + return ast.GeneratorExp(generators=compose([]), loc=None) + + @action(List(Rule("testlist_comp_1"), ",", trailing=True, leading=False)) + def testlist_comp_3(self, elts): + if elts == [] and not elts.trailing_comma: + return None + else: + return ast.Tuple(elts=elts, ctx=None, loc=None) + + testlist_comp = action( + Seq(Rule("testlist_comp_1"), Alt(testlist_comp_2, testlist_comp_3))) \ + (list_gen_action) + """ + (2.6) testlist_gexp: test ( gen_for | (',' test)* [','] ) + (2.7, 3.0, 3.1) testlist_comp: test ( comp_for | (',' test)* [','] ) + (3.2-) testlist_comp: (test|star_expr) ( comp_for | (',' (test|star_expr))* [','] ) + """ + + @action(Rule("comp_for")) + def testlist_comp__list_1(self, compose): + return ast.ListComp(generators=compose([]), loc=None) + + @action(List(Rule("testlist_comp_1"), ",", trailing=True, leading=False)) + def testlist_comp__list_2(self, elts): + return ast.List(elts=elts, ctx=None, loc=None) + + testlist_comp__list = action( + Seq(Rule("testlist_comp_1"), Alt(testlist_comp__list_1, testlist_comp__list_2))) \ + (list_gen_action) + """Same grammar as testlist_comp, but different semantic action.""" + + @action(Seq(Loc("."), Tok("ident"))) + def trailer_1(self, dot_loc, ident_tok): + return ast.Attribute(attr=ident_tok.value, ctx=None, + loc=dot_loc.join(ident_tok.loc), + attr_loc=ident_tok.loc, dot_loc=dot_loc) + + trailer = Alt(BeginEnd("(", Opt(Rule("arglist")), ")", + empty=_empty_arglist), + BeginEnd("[", Rule("subscriptlist"), "]"), + trailer_1) + """trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' 
NAME""" + + @action(List(Rule("subscript"), ",", trailing=True)) + def subscriptlist(self, subscripts): + """subscriptlist: subscript (',' subscript)* [',']""" + if len(subscripts) == 1: + return ast.Subscript(slice=subscripts[0], ctx=None, loc=None) + elif all([isinstance(x, ast.Index) for x in subscripts]): + elts = [x.value for x in subscripts] + loc = subscripts[0].loc.join(subscripts[-1].loc) + index = ast.Index(value=ast.Tuple(elts=elts, ctx=None, + begin_loc=None, end_loc=None, loc=loc), + loc=loc) + return ast.Subscript(slice=index, ctx=None, loc=None) + else: + extslice = ast.ExtSlice(dims=subscripts, + loc=subscripts[0].loc.join(subscripts[-1].loc)) + return ast.Subscript(slice=extslice, ctx=None, loc=None) + + @action(Seq(Loc("."), Loc("."), Loc("."))) + def subscript_1(self, dot_1_loc, dot_2_loc, dot_3_loc): + return ast.Ellipsis(loc=dot_1_loc.join(dot_3_loc)) + + @action(Seq(Opt(Rule("test")), Loc(":"), Opt(Rule("test")), Opt(Rule("sliceop")))) + def subscript_2(self, lower_opt, colon_loc, upper_opt, step_opt): + loc = colon_loc + if lower_opt: + loc = loc.join(lower_opt.loc) + if upper_opt: + loc = loc.join(upper_opt.loc) + step_colon_loc = step = None + if step_opt: + step_colon_loc, step = step_opt + loc = loc.join(step_colon_loc) + if step: + loc = loc.join(step.loc) + return ast.Slice(lower=lower_opt, upper=upper_opt, step=step, + loc=loc, bound_colon_loc=colon_loc, step_colon_loc=step_colon_loc) + + @action(Rule("test")) + def subscript_3(self, expr): + return ast.Index(value=expr, loc=expr.loc) + + subscript__26 = Alt(subscript_1, subscript_2, subscript_3) + """(2.6, 2.7) subscript: '.' '.' '.' | test | [test] ':' [test] [sliceop]""" + + subscript__30 = Alt(subscript_2, subscript_3) + """(3.0-) subscript: test | [test] ':' [test] [sliceop]""" + + sliceop = Seq(Loc(":"), Opt(Rule("test"))) + """sliceop: ':' [test]""" + + exprlist_1__26 = List(Rule("expr"), ",", trailing=True) + exprlist_1__30 = List(Rule("star_expr"), ",", trailing=True) + exprlist_1__32 = List(Alt(Rule("expr"), Rule("star_expr")), ",", trailing=True) + + @action(Rule("exprlist_1")) + def exprlist(self, exprs): + """ + (2.6, 2.7) exprlist: expr (',' expr)* [','] + (3.0, 3.1) exprlist: star_expr (',' star_expr)* [','] + (3.2-) exprlist: (expr|star_expr) (',' (expr|star_expr))* [','] + """ + return self._wrap_tuple(exprs) + + @action(List(Rule("test"), ",", trailing=True)) + def testlist(self, exprs): + """testlist: test (',' test)* [',']""" + return self._wrap_tuple(exprs) + + @action(List(Seq(Rule("test"), Loc(":"), Rule("test")), ",", trailing=True)) + def dictmaker(self, elts): + """(2.6) dictmaker: test ':' test (',' test ':' test)* [',']""" + return ast.Dict(keys=list(map(lambda x: x[0], elts)), + values=list(map(lambda x: x[2], elts)), + colon_locs=list(map(lambda x: x[1], elts)), + loc=None) + + dictorsetmaker_1 = Seq(Rule("test"), Loc(":"), Rule("test")) + + @action(Seq(dictorsetmaker_1, + Alt(Rule("comp_for"), + List(dictorsetmaker_1, ",", leading=False, trailing=True)))) + def dictorsetmaker_2(self, first, elts): + if isinstance(elts, commalist): + elts.insert(0, first) + return ast.Dict(keys=list(map(lambda x: x[0], elts)), + values=list(map(lambda x: x[2], elts)), + colon_locs=list(map(lambda x: x[1], elts)), + loc=None) + else: + return ast.DictComp(key=first[0], value=first[2], generators=elts([]), + colon_loc=first[1], + begin_loc=None, end_loc=None, loc=None) + + @action(Seq(Rule("test"), + Alt(Rule("comp_for"), + List(Rule("test"), ",", leading=False, trailing=True)))) + def 
dictorsetmaker_3(self, first, elts): + if isinstance(elts, commalist): + elts.insert(0, first) + return ast.Set(elts=elts, loc=None) + else: + return ast.SetComp(elt=first, generators=elts([]), + begin_loc=None, end_loc=None, loc=None) + + dictorsetmaker = Alt(dictorsetmaker_2, dictorsetmaker_3) + """ + (2.7-) + dictorsetmaker: ( (test ':' test (comp_for | (',' test ':' test)* [','])) | + (test (comp_for | (',' test)* [','])) ) + """ + + @action(Seq(Loc("class"), Tok("ident"), + Opt(Seq(Loc("("), List(Rule("test"), ",", trailing=True), Loc(")"))), + Loc(":"), Rule("suite"))) + def classdef__26(self, class_loc, name_tok, bases_opt, colon_loc, body): + """(2.6, 2.7) classdef: 'class' NAME ['(' [testlist] ')'] ':' suite""" + bases, lparen_loc, rparen_loc = [], None, None + if bases_opt: + lparen_loc, bases, rparen_loc = bases_opt + + return ast.ClassDef(name=name_tok.value, bases=bases, keywords=[], + starargs=None, kwargs=None, body=body, + decorator_list=[], at_locs=[], + keyword_loc=class_loc, lparen_loc=lparen_loc, + star_loc=None, dstar_loc=None, rparen_loc=rparen_loc, + name_loc=name_tok.loc, colon_loc=colon_loc, + loc=class_loc.join(body[-1].loc)) + + @action(Seq(Loc("class"), Tok("ident"), + Opt(Seq(Loc("("), Rule("arglist"), Loc(")"))), + Loc(":"), Rule("suite"))) + def classdef__30(self, class_loc, name_tok, arglist_opt, colon_loc, body): + """(3.0) classdef: 'class' NAME ['(' [testlist] ')'] ':' suite""" + arglist, lparen_loc, rparen_loc = [], None, None + bases, keywords, starargs, kwargs = [], [], None, None + star_loc, dstar_loc = None, None + if arglist_opt: + lparen_loc, arglist, rparen_loc = arglist_opt + bases, keywords, starargs, kwargs = \ + arglist.args, arglist.keywords, arglist.starargs, arglist.kwargs + star_loc, dstar_loc = arglist.star_loc, arglist.dstar_loc + + return ast.ClassDef(name=name_tok.value, bases=bases, keywords=keywords, + starargs=starargs, kwargs=kwargs, body=body, + decorator_list=[], at_locs=[], + keyword_loc=class_loc, lparen_loc=lparen_loc, + star_loc=star_loc, dstar_loc=dstar_loc, rparen_loc=rparen_loc, + name_loc=name_tok.loc, colon_loc=colon_loc, + loc=class_loc.join(body[-1].loc)) + + @action(Seq(Loc("*"), Rule("test"), Star(SeqN(1, Tok(","), Rule("argument"))), + Opt(Seq(Tok(","), Loc("**"), Rule("test"))))) + def arglist_1(self, star_loc, stararg, postargs, kwarg_opt): + dstar_loc = kwarg = None + if kwarg_opt: + _, dstar_loc, kwarg = kwarg_opt + + for postarg in postargs: + if not isinstance(postarg, ast.keyword): + error = diagnostic.Diagnostic( + "fatal", "only named arguments may follow *expression", {}, + postarg.loc, [star_loc.join(stararg.loc)]) + self.diagnostic_engine.process(error) + + return postargs, \ + ast.Call(args=[], keywords=[], starargs=stararg, kwargs=kwarg, + star_loc=star_loc, dstar_loc=dstar_loc, loc=None) + + @action(Seq(Loc("**"), Rule("test"))) + def arglist_2(self, dstar_loc, kwarg): + return [], \ + ast.Call(args=[], keywords=[], starargs=None, kwargs=kwarg, + star_loc=None, dstar_loc=dstar_loc, loc=None) + + @action(Seq(Rule("argument"), + Alt(SeqN(1, Tok(","), Alt(Rule("arglist_1"), + Rule("arglist_2"), + Rule("arglist_3"), + Eps())), + Eps()))) + def arglist_3(self, arg, cont): + if cont is None: + return [arg], self._empty_arglist() + else: + args, rest = cont + return [arg] + args, rest + + @action(Alt(Rule("arglist_1"), + Rule("arglist_2"), + Rule("arglist_3"))) + def arglist(self, args, call): + """arglist: (argument ',')* (argument [','] | + '*' test (',' argument)* [',' '**' test] | + '**' test)""" + for 
arg in args: + if isinstance(arg, ast.keyword): + call.keywords.append(arg) + elif len(call.keywords) > 0: + error = diagnostic.Diagnostic( + "fatal", "non-keyword arg after keyword arg", {}, + arg.loc, [call.keywords[-1].loc]) + self.diagnostic_engine.process(error) + else: + call.args.append(arg) + return call + + @action(Seq(Loc("="), Rule("test"))) + def argument_1(self, equals_loc, rhs): + def thunk(lhs): + if not isinstance(lhs, ast.Name): + error = diagnostic.Diagnostic( + "fatal", "keyword must be an identifier", {}, lhs.loc) + self.diagnostic_engine.process(error) + return ast.keyword(arg=lhs.id, value=rhs, + loc=lhs.loc.join(rhs.loc), + arg_loc=lhs.loc, equals_loc=equals_loc) + return thunk + + @action(Opt(Rule("comp_for"))) + def argument_2(self, compose_opt): + def thunk(lhs): + if compose_opt: + generators = compose_opt([]) + return ast.GeneratorExp(elt=lhs, generators=generators, + begin_loc=None, end_loc=None, + loc=lhs.loc.join(generators[-1].loc)) + return lhs + return thunk + + @action(Seq(Rule("test"), Alt(argument_1, argument_2))) + def argument(self, lhs, thunk): + # This rule is reformulated to avoid exponential backtracking. + """ + (2.6) argument: test [gen_for] | test '=' test # Really [keyword '='] test + (2.7-) argument: test [comp_for] | test '=' test + """ + return thunk(lhs) + + list_iter = Alt(Rule("list_for"), Rule("list_if")) + """(2.6, 2.7) list_iter: list_for | list_if""" + + def list_comp_for_action(self, for_loc, target, in_loc, iter, next_opt): + def compose(comprehensions): + comp = ast.comprehension( + target=target, iter=iter, ifs=[], + loc=for_loc.join(iter.loc), for_loc=for_loc, in_loc=in_loc, if_locs=[]) + comprehensions += [comp] + if next_opt: + return next_opt(comprehensions) + else: + return comprehensions + return compose + + def list_comp_if_action(self, if_loc, cond, next_opt): + def compose(comprehensions): + comprehensions[-1].ifs.append(cond) + comprehensions[-1].if_locs.append(if_loc) + comprehensions[-1].loc = comprehensions[-1].loc.join(cond.loc) + if next_opt: + return next_opt(comprehensions) + else: + return comprehensions + return compose + + list_for = action( + Seq(Loc("for"), Rule("exprlist"), + Loc("in"), Rule("testlist_safe"), Opt(Rule("list_iter")))) \ + (list_comp_for_action) + """(2.6, 2.7) list_for: 'for' exprlist 'in' testlist_safe [list_iter]""" + + list_if = action( + Seq(Loc("if"), Rule("old_test"), Opt(Rule("list_iter")))) \ + (list_comp_if_action) + """(2.6, 2.7) list_if: 'if' old_test [list_iter]""" + + comp_iter = Alt(Rule("comp_for"), Rule("comp_if")) + """ + (2.6) gen_iter: gen_for | gen_if + (2.7-) comp_iter: comp_for | comp_if + """ + + comp_for = action( + Seq(Loc("for"), Rule("exprlist"), + Loc("in"), Rule("or_test"), Opt(Rule("comp_iter")))) \ + (list_comp_for_action) + """ + (2.6) gen_for: 'for' exprlist 'in' or_test [gen_iter] + (2.7-) comp_for: 'for' exprlist 'in' or_test [comp_iter] + """ + + comp_if__26 = action( + Seq(Loc("if"), Rule("old_test"), Opt(Rule("comp_iter")))) \ + (list_comp_if_action) + """ + (2.6) gen_if: 'if' old_test [gen_iter] + (2.7) comp_if: 'if' old_test [comp_iter] + """ + + comp_if__30 = action( + Seq(Loc("if"), Rule("test_nocond"), Opt(Rule("comp_iter")))) \ + (list_comp_if_action) + """ + (3.0-) comp_if: 'if' test_nocond [comp_iter] + """ + + testlist1 = action(List(Rule("test"), ",", trailing=False))(_wrap_tuple) + """testlist1: test (',' test)*""" + + @action(Seq(Loc("yield"), Opt(Rule("testlist")))) + def yield_expr__26(self, yield_loc, exprs): + """(2.6, 2.7, 3.0, 3.1, 3.2) 
yield_expr: 'yield' [testlist]""" + if exprs is not None: + return ast.Yield(value=exprs, + yield_loc=yield_loc, loc=yield_loc.join(exprs.loc)) + else: + return ast.Yield(value=None, + yield_loc=yield_loc, loc=yield_loc) + + @action(Seq(Loc("yield"), Opt(Rule("yield_arg")))) + def yield_expr__33(self, yield_loc, arg): + """(3.3-) yield_expr: 'yield' [yield_arg]""" + if isinstance(arg, ast.YieldFrom): + arg.yield_loc = yield_loc + arg.loc = arg.loc.join(arg.yield_loc) + return arg + elif arg is not None: + return ast.Yield(value=arg, + yield_loc=yield_loc, loc=yield_loc.join(arg.loc)) + else: + return ast.Yield(value=None, + yield_loc=yield_loc, loc=yield_loc) + + @action(Seq(Loc("from"), Rule("test"))) + def yield_arg_1(self, from_loc, value): + return ast.YieldFrom(value=value, + from_loc=from_loc, loc=from_loc.join(value.loc)) + + yield_arg = Alt(yield_arg_1, Rule("testlist")) + """(3.3-) yield_arg: 'from' test | testlist""" diff --git a/third_party/pythonparser/source.py b/third_party/pythonparser/source.py new file mode 100644 index 00000000..c40d9e7a --- /dev/null +++ b/third_party/pythonparser/source.py @@ -0,0 +1,310 @@ +""" +The :mod:`source` module concerns itself with manipulating +buffers of source code: creating ranges of characters corresponding +to a token, combining these ranges, extracting human-readable +location information and original source from a range. +""" + +from __future__ import absolute_import, division, print_function, unicode_literals +import bisect +import re + +class Buffer: + """ + A buffer containing source code and location information. + + :ivar source: (string) source code + :ivar name: (string) input filename or another description + of the input (e.g. ````). + :ivar line: (integer) first line of the input + """ + def __init__(self, source, name="", first_line=1): + self.encoding = self._extract_encoding(source) + if isinstance(source, bytes): + self.source = source.decode(self.encoding) + else: + self.source = source + self.name = name + self.first_line = first_line + self._line_begins = None + + def __repr__(self): + return "Buffer(\"%s\")" % self.name + + def source_line(self, lineno): + """ + Returns line ``lineno`` from source, taking ``first_line`` into account, + or raises :exc:`IndexError` if ``lineno`` is out of range. + """ + line_begins = self._extract_line_begins() + lineno = lineno - self.first_line + if lineno >= 0 and lineno + 1 < len(line_begins): + first, last = line_begins[lineno:lineno + 2] + return self.source[first:last] + elif lineno >= 0 and lineno < len(line_begins): + return self.source[line_begins[-1]:] + else: + raise IndexError + + def decompose_position(self, offset): + """ + Returns a ``line, column`` tuple for a character offset into the source, + orraises :exc:`IndexError` if ``lineno`` is out of range. 
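For readers unfamiliar with the source-location API added above, a minimal usage sketch (illustrative only, not part of the vendored file; it assumes the package is importable as pythonparser, e.g. with third_party on sys.path, and uses a made-up two-line program):

from pythonparser import source

buf = source.Buffer("x = 1\ny = x + 2\n", name="example.py")

# decompose_position maps a character offset to a (line, column) pair,
# with 1-based lines and 0-based columns.
print(buf.decompose_position(buf.source.index("+")))   # (2, 6)

# source_line returns the full text of a line, newline included.
print(buf.source_line(2))                               # "y = x + 2\n"

# A Range covers [begin_pos, end_pos) of a Buffer and can locate itself.
rng = source.Range(buf, 6, 15)
print(str(rng))       # "example.py:2:1-2:10" (Clang-style location)
print(rng.source())   # "y = x + 2"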
+ """ + line_begins = self._extract_line_begins() + lineno = bisect.bisect_right(line_begins, offset) - 1 + if offset >= 0 and offset <= len(self.source): + return lineno + self.first_line, offset - line_begins[lineno] + else: + raise IndexError + + def _extract_line_begins(self): + if self._line_begins: + return self._line_begins + + self._line_begins = [0] + index = None + while True: + index = self.source.find("\n", index) + 1 + if index == 0: + return self._line_begins + self._line_begins.append(index) + + _encoding_re = re.compile("^[ \t\v]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)") + _encoding_bytes_re = re.compile(_encoding_re.pattern.encode()) + + def _extract_encoding(self, source): + if isinstance(source, bytes): + re = self._encoding_bytes_re + nl = b"\n" + else: + re = self._encoding_re + nl = "\n" + match = re.match(source) + if not match: + index = source.find(nl) + if index != -1: + match = re.match(source[index + 1:]) + if match: + encoding = match.group(1) + if isinstance(encoding, bytes): + return encoding.decode("ascii") + return encoding + return "ascii" + + +class Range: + """ + Location of an exclusive range of characters [*begin_pos*, *end_pos*) + in a :class:`Buffer`. + + :ivar begin_pos: (integer) offset of the first character + :ivar end_pos: (integer) offset of the character before the last + :ivar expanded_from: (Range or None) the range from which this range was expanded + """ + def __init__(self, source_buffer, begin_pos, end_pos, expanded_from=None): + self.source_buffer = source_buffer + self.begin_pos = begin_pos + self.end_pos = end_pos + self.expanded_from = expanded_from + + def __repr__(self): + """ + Returns a human-readable representation of this range. + """ + return "Range(\"%s\", %d, %d, %s)" % \ + (self.source_buffer.name, self.begin_pos, self.end_pos, repr(self.expanded_from)) + + def chain(self, expanded_from): + """ + Returns a range identical to this one, but indicating that + it was expanded from the range `expanded_from`. + """ + return Range(self.source_buffer, self.begin_pos, self.begin_pos, + expanded_from=expanded_from) + + def begin(self): + """ + Returns a zero-length range located just before the beginning of this range. + """ + return Range(self.source_buffer, self.begin_pos, self.begin_pos, + expanded_from=self.expanded_from) + + def end(self): + """ + Returns a zero-length range located just after the end of this range. + """ + return Range(self.source_buffer, self.end_pos, self.end_pos, + expanded_from=self.expanded_from) + + def size(self): + """ + Returns the amount of characters spanned by the range. + """ + return self.end_pos - self.begin_pos + + def column(self): + """ + Returns a zero-based column number of the beginning of this range. + """ + line, column = self.source_buffer.decompose_position(self.begin_pos) + return column + + def column_range(self): + """ + Returns a [*begin*, *end*) tuple describing the range of columns spanned + by this range. If range spans more than one line, returned *end* is + the last column of the line. + """ + if self.begin().line() == self.end().line(): + return self.begin().column(), self.end().column() + else: + return self.begin().column(), len(self.begin().source_line()) - 1 + + def line(self): + """ + Returns the line number of the beginning of this range. + """ + line, column = self.source_buffer.decompose_position(self.begin_pos) + return line + + def join(self, other): + """ + Returns the smallest possible range spanning both this range and other. 
+ Raises :exc:`ValueError` if the ranges do not belong to the same + :class:`Buffer`. + """ + if self.source_buffer != other.source_buffer: + raise ValueError + if self.expanded_from == other.expanded_from: + expanded_from = self.expanded_from + else: + expanded_from = None + return Range(self.source_buffer, + min(self.begin_pos, other.begin_pos), + max(self.end_pos, other.end_pos), + expanded_from=expanded_from) + + def source(self): + """ + Returns the source code covered by this range. + """ + return self.source_buffer.source[self.begin_pos:self.end_pos] + + def source_line(self): + """ + Returns the line of source code containing the beginning of this range. + """ + return self.source_buffer.source_line(self.line()) + + def source_lines(self): + """ + Returns the lines of source code containing the entirety of this range. + """ + return [self.source_buffer.source_line(line) + for line in range(self.line(), self.end().line() + 1)] + + def __str__(self): + """ + Returns a Clang-style string representation of the beginning of this range. + """ + if self.begin_pos != self.end_pos: + return "%s:%d:%d-%d:%d" % (self.source_buffer.name, + self.line(), self.column() + 1, + self.end().line(), self.end().column() + 1) + else: + return "%s:%d:%d" % (self.source_buffer.name, + self.line(), self.column() + 1) + + def __eq__(self, other): + """ + Returns true if the ranges have the same source buffer, start and end position. + """ + return (type(self) == type(other) and + self.source_buffer == other.source_buffer and + self.begin_pos == other.begin_pos and + self.end_pos == other.end_pos and + self.expanded_from == other.expanded_from) + + def __ne__(self, other): + """ + Inverse of :meth:`__eq__`. + """ + return not (self == other) + + def __hash__(self): + return hash((self.source_buffer, self.begin_pos, self.end_pos, self.expanded_from)) + +class Comment: + """ + A comment in the source code. + + :ivar loc: (:class:`Range`) source location + :ivar text: (string) comment text + """ + + def __init__(self, loc, text): + self.loc, self.text = loc, text + +class RewriterConflict(Exception): + """ + An exception that is raised when two ranges supplied to a rewriter overlap. + + :ivar first: (:class:`Range`) first overlapping range + :ivar second: (:class:`Range`) second overlapping range + """ + + def __init__(self, first, second): + self.first, self.second = first, second + exception.__init__(self, "Ranges %s and %s overlap" % (repr(first), repr(second))) + +class Rewriter: + """ + The :class:`Rewriter` class rewrites source code: performs bulk modification + guided by a list of ranges and code fragments replacing their original + content. + + :ivar buffer: (:class:`Buffer`) buffer + """ + + def __init__(self, buffer): + self.buffer = buffer + self.ranges = [] + + def replace(self, range, replacement): + """Remove `range` and replace it with string `replacement`.""" + self.ranges.append((range, replacement)) + + def remove(self, range): + """Remove `range`.""" + self.replace(range, "") + + def insert_before(self, range, text): + """Insert `text` before `range`.""" + self.replace(range.begin(), text) + + def insert_after(self, range, text): + """Insert `text` after `range`.""" + self.replace(range.end(), text) + + def rewrite(self): + """Return the rewritten source. 
May raise :class:`RewriterConflict`.""" + self._sort() + self._check() + + rewritten, pos = [], 0 + for range, replacement in self.ranges: + rewritten.append(self.buffer.source[pos:range.begin_pos]) + rewritten.append(replacement) + pos = range.end_pos + rewritten.append(self.buffer.source[pos:]) + + return Buffer("".join(rewritten), self.buffer.name, self.buffer.first_line) + + def _sort(self): + self.ranges.sort(key=lambda x: x[0].begin_pos) + + def _check(self): + for (fst, _), (snd, _) in zip(self.ranges, self.ranges[1:]): + if snd.begin_pos < fst.end_pos: + raise RewriterConflict(fst, snd) diff --git a/third_party/stdlib/Queue.py b/third_party/stdlib/Queue.py new file mode 100644 index 00000000..62de009f --- /dev/null +++ b/third_party/stdlib/Queue.py @@ -0,0 +1,244 @@ +"""A multi-producer, multi-consumer queue.""" + +from time import time as _time +#try: +import threading as _threading +#except ImportError: +# import dummy_threading as _threading +from collections import deque +import heapq + +__all__ = ['Empty', 'Full', 'Queue', 'PriorityQueue', 'LifoQueue'] + +class Empty(Exception): + "Exception raised by Queue.get(block=0)/get_nowait()." + pass + +class Full(Exception): + "Exception raised by Queue.put(block=0)/put_nowait()." + pass + +class Queue(object): + """Create a queue object with a given maximum size. + + If maxsize is <= 0, the queue size is infinite. + """ + def __init__(self, maxsize=0): + self.maxsize = maxsize + self._init(maxsize) + # mutex must be held whenever the queue is mutating. All methods + # that acquire mutex must release it before returning. mutex + # is shared between the three conditions, so acquiring and + # releasing the conditions also acquires and releases mutex. + self.mutex = _threading.Lock() + # Notify not_empty whenever an item is added to the queue; a + # thread waiting to get is notified then. + self.not_empty = _threading.Condition(self.mutex) + # Notify not_full whenever an item is removed from the queue; + # a thread waiting to put is notified then. + self.not_full = _threading.Condition(self.mutex) + # Notify all_tasks_done whenever the number of unfinished tasks + # drops to zero; thread waiting to join() is notified to resume + self.all_tasks_done = _threading.Condition(self.mutex) + self.unfinished_tasks = 0 + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by Queue consumer threads. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items + have been processed (meaning that a task_done() call was received + for every item that had been put() into the queue). + + Raises a ValueError if called more times than there were items + placed in the queue. + """ + self.all_tasks_done.acquire() + try: + unfinished = self.unfinished_tasks - 1 + if unfinished <= 0: + if unfinished < 0: + raise ValueError('task_done() called too many times') + self.all_tasks_done.notify_all() + self.unfinished_tasks = unfinished + finally: + self.all_tasks_done.release() + + def join(self): + """Blocks until all items in the Queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate the item was retrieved and all work on it is complete. + + When the count of unfinished tasks drops to zero, join() unblocks. 
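The task_done()/join() accounting described in the two docstrings above is easiest to see in a small producer/consumer sketch (illustrative only; Python 2.7, where the vendored module is imported as Queue -- it is queue in Python 3):

import threading
from Queue import Queue

q = Queue()

def worker():
    while True:
        item = q.get()     # blocks until an item is available
        # ... process item here ...
        q.task_done()      # exactly one task_done() per completed get()

t = threading.Thread(target=worker)
t.daemon = True            # do not keep the process alive for this thread
t.start()

for item in range(10):
    q.put(item)            # each put() bumps unfinished_tasks

q.join()                   # returns once every item has been marked done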
+ """ + self.all_tasks_done.acquire() + try: + while self.unfinished_tasks: + self.all_tasks_done.wait() + finally: + self.all_tasks_done.release() + + def qsize(self): + """Return the approximate size of the queue (not reliable!).""" + self.mutex.acquire() + n = self._qsize() + self.mutex.release() + return n + + def empty(self): + """Return True if the queue is empty, False otherwise (not reliable!).""" + self.mutex.acquire() + n = not self._qsize() + self.mutex.release() + return n + + def full(self): + """Return True if the queue is full, False otherwise (not reliable!).""" + self.mutex.acquire() + n = 0 < self.maxsize == self._qsize() + self.mutex.release() + return n + + def put(self, item, block=True, timeout=None): + """Put an item into the queue. + + If optional args 'block' is true and 'timeout' is None (the default), + block if necessary until a free slot is available. If 'timeout' is + a non-negative number, it blocks at most 'timeout' seconds and raises + the Full exception if no free slot was available within that time. + Otherwise ('block' is false), put an item on the queue if a free slot + is immediately available, else raise the Full exception ('timeout' + is ignored in that case). + """ + self.not_full.acquire() + try: + if self.maxsize > 0: + if not block: + if self._qsize() == self.maxsize: + raise Full + elif timeout is None: + while self._qsize() == self.maxsize: + self.not_full.wait() + elif timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + else: + endtime = _time() + timeout + while self._qsize() == self.maxsize: + remaining = endtime - _time() + if remaining <= 0.0: + raise Full + self.not_full.wait(remaining) + self._put(item) + self.unfinished_tasks += 1 + self.not_empty.notify() + finally: + self.not_full.release() + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + Only enqueue the item if a free slot is immediately available. + Otherwise raise the Full exception. + """ + return self.put(item, False) + + def get(self, block=True, timeout=None): + """Remove and return an item from the queue. + + If optional args 'block' is true and 'timeout' is None (the default), + block if necessary until an item is available. If 'timeout' is + a non-negative number, it blocks at most 'timeout' seconds and raises + the Empty exception if no item was available within that time. + Otherwise ('block' is false), return an item if one is immediately + available, else raise the Empty exception ('timeout' is ignored + in that case). + """ + self.not_empty.acquire() + try: + if not block: + if not self._qsize(): + raise Empty + elif timeout is None: + while not self._qsize(): + self.not_empty.wait() + elif timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + else: + endtime = _time() + timeout + while not self._qsize(): + remaining = endtime - _time() + if remaining <= 0.0: + raise Empty + self.not_empty.wait(remaining) + item = self._get() + self.not_full.notify() + return item + finally: + self.not_empty.release() + + def get_nowait(self): + """Remove and return an item from the queue without blocking. + + Only get an item if one is immediately available. Otherwise + raise the Empty exception. + """ + return self.get(False) + + # Override these methods to implement other queue organizations + # (e.g. stack or priority queue). 
+ # These will only be called with appropriate locks held + + # Initialize the queue representation + def _init(self, maxsize): + self.queue = deque() + + def _qsize(self, len=len): + return len(self.queue) + + # Put a new item in the queue + def _put(self, item): + self.queue.append(item) + + # Get an item from the queue + def _get(self): + return self.queue.popleft() + + +class PriorityQueue(Queue): + '''Variant of Queue that retrieves open entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + ''' + + def _init(self, maxsize): + self.queue = [] + + def _qsize(self, len=len): + return len(self.queue) + + def _put(self, item, heappush=heapq.heappush): + heappush(self.queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self.queue) + + +class LifoQueue(Queue): + '''Variant of Queue that retrieves most recently added entries first.''' + + def _init(self, maxsize): + self.queue = [] + + def _qsize(self, len=len): + return len(self.queue) + + def _put(self, item): + self.queue.append(item) + + def _get(self): + return self.queue.pop() diff --git a/third_party/stdlib/README.md b/third_party/stdlib/README.md new file mode 100644 index 00000000..9e95d841 --- /dev/null +++ b/third_party/stdlib/README.md @@ -0,0 +1,3 @@ +Canonical versions of the files in this folder come from the +[Lib](https://github.com/python/cpython/tree/2.7/Lib) directory of the +[2.7 branch of the CPython repo](https://github.com/python/cpython/tree/2.7). diff --git a/third_party/stdlib/_weakrefset.py b/third_party/stdlib/_weakrefset.py index 4365c0b5..e326dc62 100644 --- a/third_party/stdlib/_weakrefset.py +++ b/third_party/stdlib/_weakrefset.py @@ -2,7 +2,7 @@ # This code is separated-out because it is needed # by abc.py to load everything else at startup. -from __go__.grumpy import WeakRefType as ref +from '__go__/grumpy' import WeakRefType as ref __all__ = ['WeakSet'] diff --git a/third_party/stdlib/bisect.py b/third_party/stdlib/bisect.py new file mode 100644 index 00000000..d36f3657 --- /dev/null +++ b/third_party/stdlib/bisect.py @@ -0,0 +1,92 @@ +"""Bisection algorithms.""" + +def insort_right(a, x, lo=0, hi=None): + """Insert item x in list a, and keep it sorted assuming a is sorted. + + If x is already in a, insert it to the right of the rightmost x. + + Optional args lo (default 0) and hi (default len(a)) bound the + slice of a to be searched. + """ + + if lo < 0: + raise ValueError('lo must be non-negative') + if hi is None: + hi = len(a) + while lo < hi: + mid = (lo+hi)//2 + if x < a[mid]: hi = mid + else: lo = mid+1 + a.insert(lo, x) + +insort = insort_right # backward compatibility + +def bisect_right(a, x, lo=0, hi=None): + """Return the index where to insert item x in list a, assuming a is sorted. + + The return value i is such that all e in a[:i] have e <= x, and all e in + a[i:] have e > x. So if x already appears in the list, a.insert(x) will + insert just after the rightmost x already there. + + Optional args lo (default 0) and hi (default len(a)) bound the + slice of a to be searched. + """ + + if lo < 0: + raise ValueError('lo must be non-negative') + if hi is None: + hi = len(a) + while lo < hi: + mid = (lo+hi)//2 + if x < a[mid]: hi = mid + else: lo = mid+1 + return lo + +bisect = bisect_right # backward compatibility + +def insort_left(a, x, lo=0, hi=None): + """Insert item x in list a, and keep it sorted assuming a is sorted. + + If x is already in a, insert it to the left of the leftmost x. 
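As a quick illustration of the bisect API being vendored here (sketch only, not part of the file): keeping a list sorted, and the left/right insertion points for duplicate values.

import bisect

scores = [55, 70, 70, 90]

print(bisect.bisect_right(scores, 70))   # 3: after the existing 70s
print(bisect.bisect_left(scores, 70))    # 1: before the existing 70s

bisect.insort(scores, 80)                # insert while keeping the order
print(scores)                            # [55, 70, 70, 80, 90]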
+ + Optional args lo (default 0) and hi (default len(a)) bound the + slice of a to be searched. + """ + + if lo < 0: + raise ValueError('lo must be non-negative') + if hi is None: + hi = len(a) + while lo < hi: + mid = (lo+hi)//2 + if a[mid] < x: lo = mid+1 + else: hi = mid + a.insert(lo, x) + + +def bisect_left(a, x, lo=0, hi=None): + """Return the index where to insert item x in list a, assuming a is sorted. + + The return value i is such that all e in a[:i] have e < x, and all e in + a[i:] have e >= x. So if x already appears in the list, a.insert(x) will + insert just before the leftmost x already there. + + Optional args lo (default 0) and hi (default len(a)) bound the + slice of a to be searched. + """ + + if lo < 0: + raise ValueError('lo must be non-negative') + if hi is None: + hi = len(a) + while lo < hi: + mid = (lo+hi)//2 + if a[mid] < x: lo = mid+1 + else: hi = mid + return lo + +# Overwrite above definitions with a fast C implementation +# try: +# from _bisect import * +# except ImportError: +# pass diff --git a/third_party/stdlib/colorsys.py b/third_party/stdlib/colorsys.py new file mode 100644 index 00000000..bf7a9e5c --- /dev/null +++ b/third_party/stdlib/colorsys.py @@ -0,0 +1,159 @@ +"""Conversion functions between RGB and other color systems. +This modules provides two functions for each color system ABC: + rgb_to_abc(r, g, b) --> a, b, c + abc_to_rgb(a, b, c) --> r, g, b +All inputs and outputs are triples of floats in the range [0.0...1.0] +(with the exception of I and Q, which covers a slightly larger range). +Inputs outside the valid range may cause exceptions or invalid outputs. +Supported color systems: +RGB: Red, Green, Blue components +YIQ: Luminance, Chrominance (used by composite video signals) +HLS: Hue, Luminance, Saturation +HSV: Hue, Saturation, Value +""" + +# References: +# http://en.wikipedia.org/wiki/YIQ +# http://en.wikipedia.org/wiki/HLS_color_space +# http://en.wikipedia.org/wiki/HSV_color_space + +__all__ = ["rgb_to_yiq","yiq_to_rgb","rgb_to_hls","hls_to_rgb", + "rgb_to_hsv","hsv_to_rgb"] + +# Some floating point constants + +ONE_THIRD = 1.0/3.0 +ONE_SIXTH = 1.0/6.0 +TWO_THIRD = 2.0/3.0 + +# YIQ: used by composite video signals (linear combinations of RGB) +# Y: perceived grey level (0.0 == black, 1.0 == white) +# I, Q: color components +# +# There are a great many versions of the constants used in these formulae. +# The ones in this library uses constants from the FCC version of NTSC. 
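A round-trip sketch of how these conversion routines are used (illustrative only; all channels are floats in [0.0, 1.0]):

import colorsys

r, g, b = 0.2, 0.4, 0.4

h, l, s = colorsys.rgb_to_hls(r, g, b)
print(colorsys.hls_to_rgb(h, l, s))   # approximately (0.2, 0.4, 0.4)

y, i, q = colorsys.rgb_to_yiq(r, g, b)
print(colorsys.yiq_to_rgb(y, i, q))   # approximately (0.2, 0.4, 0.4)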
+ +def rgb_to_yiq(r, g, b): + y = 0.30*r + 0.59*g + 0.11*b + i = 0.74*(r-y) - 0.27*(b-y) + q = 0.48*(r-y) + 0.41*(b-y) + return (y, i, q) + +def yiq_to_rgb(y, i, q): + # r = y + (0.27*q + 0.41*i) / (0.74*0.41 + 0.27*0.48) + # b = y + (0.74*q - 0.48*i) / (0.74*0.41 + 0.27*0.48) + # g = y - (0.30*(r-y) + 0.11*(b-y)) / 0.59 + + r = y + 0.9468822170900693*i + 0.6235565819861433*q + g = y - 0.27478764629897834*i - 0.6356910791873801*q + b = y - 1.1085450346420322*i + 1.7090069284064666*q + + if r < 0.0: + r = 0.0 + if g < 0.0: + g = 0.0 + if b < 0.0: + b = 0.0 + if r > 1.0: + r = 1.0 + if g > 1.0: + g = 1.0 + if b > 1.0: + b = 1.0 + return (r, g, b) + + +# HLS: Hue, Luminance, Saturation +# H: position in the spectrum +# L: color lightness +# S: color saturation + +def rgb_to_hls(r, g, b): + maxc = max(r, g, b) + minc = min(r, g, b) + # XXX Can optimize (maxc+minc) and (maxc-minc) + l = (minc+maxc)/2.0 + if minc == maxc: + return 0.0, l, 0.0 + if l <= 0.5: + s = (maxc-minc) / (maxc+minc) + else: + s = (maxc-minc) / (2.0-maxc-minc) + rc = (maxc-r) / (maxc-minc) + gc = (maxc-g) / (maxc-minc) + bc = (maxc-b) / (maxc-minc) + if r == maxc: + h = bc-gc + elif g == maxc: + h = 2.0+rc-bc + else: + h = 4.0+gc-rc + h = (h/6.0) % 1.0 + return h, l, s + +def hls_to_rgb(h, l, s): + if s == 0.0: + return l, l, l + if l <= 0.5: + m2 = l * (1.0+s) + else: + m2 = l+s-(l*s) + m1 = 2.0*l - m2 + return (_v(m1, m2, h+ONE_THIRD), _v(m1, m2, h), _v(m1, m2, h-ONE_THIRD)) + +def _v(m1, m2, hue): + hue = hue % 1.0 + if hue < ONE_SIXTH: + return m1 + (m2-m1)*hue*6.0 + if hue < 0.5: + return m2 + if hue < TWO_THIRD: + return m1 + (m2-m1)*(TWO_THIRD-hue)*6.0 + return m1 + + +# HSV: Hue, Saturation, Value +# H: position in the spectrum +# S: color saturation ("purity") +# V: color brightness + +def rgb_to_hsv(r, g, b): + maxc = max(r, g, b) + minc = min(r, g, b) + v = maxc + if minc == maxc: + return 0.0, 0.0, v + s = (maxc-minc) / maxc + rc = (maxc-r) / (maxc-minc) + gc = (maxc-g) / (maxc-minc) + bc = (maxc-b) / (maxc-minc) + if r == maxc: + h = bc-gc + elif g == maxc: + h = 2.0+rc-bc + else: + h = 4.0+gc-rc + h = (h/6.0) % 1.0 + return h, s, v + +def hsv_to_rgb(h, s, v): + if s == 0.0: + return v, v, v + i = int(h*6.0) # XXX assume int() truncates! + f = (h*6.0) - i + p = v*(1.0 - s) + q = v*(1.0 - s*f) + t = v*(1.0 - s*(1.0-f)) + i = i%6 + if i == 0: + return v, t, p + if i == 1: + return q, v, p + if i == 2: + return p, v, t + if i == 3: + return p, q, v + if i == 4: + return t, p, v + if i == 5: + return v, p, q diff --git a/third_party/stdlib/copy.py b/third_party/stdlib/copy.py index 95cee440..ef83ef20 100644 --- a/third_party/stdlib/copy.py +++ b/third_party/stdlib/copy.py @@ -48,7 +48,7 @@ class instances). "pickle" for information on these methods. """ -from __go__.grumpy import WeakRefType +from '__go__/grumpy' import WeakRefType import types #from copy_reg import dispatch_table import copy_reg diff --git a/third_party/stdlib/dircache.py b/third_party/stdlib/dircache.py new file mode 100644 index 00000000..7e4f0b50 --- /dev/null +++ b/third_party/stdlib/dircache.py @@ -0,0 +1,41 @@ +"""Read and cache directory listings. + +The listdir() routine returns a sorted list of the files in a directory, +using a cache to avoid reading the directory more often than necessary. 
+The annotate() routine appends slashes to directories.""" +from warnings import warnpy3k +warnpy3k("the dircache module has been removed in Python 3.0", stacklevel=2) +del warnpy3k + +import os + +__all__ = ["listdir", "opendir", "annotate", "reset"] + +cache = {} + +def reset(): + """Reset the cache completely.""" + global cache + cache = {} + +def listdir(path): + """List directory contents, using cache.""" + try: + cached_mtime, list = cache[path] + del cache[path] + except KeyError: + cached_mtime, list = -1, [] + mtime = os.stat(path).st_mtime + if mtime != cached_mtime: + list = os.listdir(path) + list.sort() + cache[path] = mtime, list + return list + +opendir = listdir # XXX backward compatibility + +def annotate(head, list): + """Add '/' suffixes to directories.""" + for i in range(len(list)): + if os.path.isdir(os.path.join(head, list[i])): + list[i] = list[i] + '/' diff --git a/third_party/stdlib/dummy_thread.py b/third_party/stdlib/dummy_thread.py new file mode 100644 index 00000000..198dc49d --- /dev/null +++ b/third_party/stdlib/dummy_thread.py @@ -0,0 +1,145 @@ +"""Drop-in replacement for the thread module. + +Meant to be used as a brain-dead substitute so that threaded code does +not need to be rewritten for when the thread module is not present. + +Suggested usage is:: + + try: + import thread + except ImportError: + import dummy_thread as thread + +""" +# Exports only things specified by thread documentation; +# skipping obsolete synonyms allocate(), start_new(), exit_thread(). +__all__ = ['error', 'start_new_thread', 'exit', 'get_ident', 'allocate_lock', + 'interrupt_main', 'LockType'] + +import traceback as _traceback + +class error(Exception): + """Dummy implementation of thread.error.""" + + def __init__(self, *args): + self.args = args + +def start_new_thread(function, args, kwargs={}): + """Dummy implementation of thread.start_new_thread(). + + Compatibility is maintained by making sure that ``args`` is a + tuple and ``kwargs`` is a dictionary. If an exception is raised + and it is SystemExit (which can be done by thread.exit()) it is + caught and nothing is done; all other exceptions are printed out + by using traceback.print_exc(). + + If the executed function calls interrupt_main the KeyboardInterrupt will be + raised when the function returns. + + """ + if type(args) != type(tuple()): + raise TypeError("2nd arg must be a tuple") + if type(kwargs) != type(dict()): + raise TypeError("3rd arg must be a dict") + global _main + _main = False + try: + function(*args, **kwargs) + except SystemExit: + pass + except: + _traceback.print_exc() + _main = True + global _interrupt + if _interrupt: + _interrupt = False + raise KeyboardInterrupt + +def exit(): + """Dummy implementation of thread.exit().""" + raise SystemExit + +def get_ident(): + """Dummy implementation of thread.get_ident(). + + Since this module should only be used when threadmodule is not + available, it is safe to assume that the current process is the + only thread. Thus a constant can be safely returned. + """ + return -1 + +def allocate_lock(): + """Dummy implementation of thread.allocate_lock().""" + return LockType() + +def stack_size(size=None): + """Dummy implementation of thread.stack_size().""" + if size is not None: + raise error("setting thread stack size not supported") + return 0 + +class LockType(object): + """Class implementing dummy implementation of thread.LockType. + + Compatibility is maintained by maintaining self.locked_status + which is a boolean that stores the state of the lock. 
Pickling of + the lock, though, should not be done since if the thread module is + then used with an unpickled ``lock()`` from here problems could + occur from this class not having atomic methods. + + """ + + def __init__(self): + self.locked_status = False + + def acquire(self, waitflag=None): + """Dummy implementation of acquire(). + + For blocking calls, self.locked_status is automatically set to + True and returned appropriately based on value of + ``waitflag``. If it is non-blocking, then the value is + actually checked and not set if it is already acquired. This + is all done so that threading.Condition's assert statements + aren't triggered and throw a little fit. + + """ + if waitflag is None or waitflag: + self.locked_status = True + return True + else: + if not self.locked_status: + self.locked_status = True + return True + else: + return False + + __enter__ = acquire + + def __exit__(self, typ, val, tb): + self.release() + + def release(self): + """Release the dummy lock.""" + # XXX Perhaps shouldn't actually bother to test? Could lead + # to problems for complex, threaded code. + if not self.locked_status: + raise error + self.locked_status = False + return True + + def locked(self): + return self.locked_status + +# Used to signal that interrupt_main was called in a "thread" +_interrupt = False +# True when not executing in a "thread" +_main = True + +def interrupt_main(): + """Set _interrupt flag to True to have start_new_thread raise + KeyboardInterrupt upon exiting.""" + if _main: + raise KeyboardInterrupt + else: + global _interrupt + _interrupt = True diff --git a/third_party/stdlib/fpformat.py b/third_party/stdlib/fpformat.py new file mode 100644 index 00000000..71cbb25f --- /dev/null +++ b/third_party/stdlib/fpformat.py @@ -0,0 +1,145 @@ +"""General floating point formatting functions. + +Functions: +fix(x, digits_behind) +sci(x, digits_behind) + +Each takes a number or a string and a number of digits as arguments. 
+ +Parameters: +x: number to be formatted; or a string resembling a number +digits_behind: number of digits behind the decimal point +""" +from warnings import warnpy3k +warnpy3k("the fpformat module has been removed in Python 3.0", stacklevel=2) +del warnpy3k + +import re + +__all__ = ["fix","sci","NotANumber"] + +# Compiled regular expression to "decode" a number +decoder = re.compile(r'^([-+]?)0*(\d*)((?:\.\d*)?)(([eE][-+]?\d+)?)$') +# \0 the whole thing +# \1 leading sign or empty +# \2 digits left of decimal point +# \3 fraction (empty or begins with point) +# \4 exponent part (empty or begins with 'e' or 'E') + +try: + class NotANumber(ValueError): + pass +except TypeError: + NotANumber = 'fpformat.NotANumber' + +def extract(s): + """Return (sign, intpart, fraction, expo) or raise an exception: + sign is '+' or '-' + intpart is 0 or more digits beginning with a nonzero + fraction is 0 or more digits + expo is an integer""" + res = decoder.match(s) + if res is None: raise NotANumber, s + sign, intpart, fraction, exppart = res.group(1,2,3,4) + if sign == '+': sign = '' + if fraction: fraction = fraction[1:] + if exppart: expo = int(exppart[1:]) + else: expo = 0 + return sign, intpart, fraction, expo + +def unexpo(intpart, fraction, expo): + """Remove the exponent by changing intpart and fraction.""" + if expo > 0: # Move the point left + f = len(fraction) + intpart, fraction = intpart + fraction[:expo], fraction[expo:] + if expo > f: + intpart = intpart + '0'*(expo-f) + elif expo < 0: # Move the point right + i = len(intpart) + intpart, fraction = intpart[:expo], intpart[expo:] + fraction + if expo < -i: + fraction = '0'*(-expo-i) + fraction + return intpart, fraction + +def roundfrac(intpart, fraction, digs): + """Round or extend the fraction to size digs.""" + f = len(fraction) + if f <= digs: + return intpart, fraction + '0'*(digs-f) + i = len(intpart) + if i+digs < 0: + return '0'*-digs, '' + total = intpart + fraction + nextdigit = total[i+digs] + if nextdigit >= '5': # Hard case: increment last digit, may have carry! + n = i + digs - 1 + while n >= 0: + if total[n] != '9': break + n = n-1 + else: + total = '0' + total + i = i+1 + n = 0 + total = total[:n] + chr(ord(total[n]) + 1) + '0'*(len(total)-n-1) + intpart, fraction = total[:i], total[i:] + if digs >= 0: + return intpart, fraction[:digs] + else: + return intpart[:digs] + '0'*-digs, '' + +def fix(x, digs): + """Format x as [-]ddd.ddd with 'digs' digits after the point + and at least one digit before. + If digs <= 0, the point is suppressed.""" + if type(x) != type(''): x = repr(x) + try: + sign, intpart, fraction, expo = extract(x) + except NotANumber: + return x + intpart, fraction = unexpo(intpart, fraction, expo) + intpart, fraction = roundfrac(intpart, fraction, digs) + while intpart and intpart[0] == '0': intpart = intpart[1:] + if intpart == '': intpart = '0' + if digs > 0: return sign + intpart + '.' + fraction + else: return sign + intpart + +def sci(x, digs): + """Format x as [-]d.dddE[+-]ddd with 'digs' digits after the point + and exactly one digit before. 
+ If digs is <= 0, one digit is kept and the point is suppressed.""" + if type(x) != type(''): x = repr(x) + sign, intpart, fraction, expo = extract(x) + if not intpart: + while fraction and fraction[0] == '0': + fraction = fraction[1:] + expo = expo - 1 + if fraction: + intpart, fraction = fraction[0], fraction[1:] + expo = expo - 1 + else: + intpart = '0' + else: + expo = expo + len(intpart) - 1 + intpart, fraction = intpart[0], intpart[1:] + fraction + digs = max(0, digs) + intpart, fraction = roundfrac(intpart, fraction, digs) + if len(intpart) > 1: + intpart, fraction, expo = \ + intpart[0], intpart[1:] + fraction[:-1], \ + expo + len(intpart) - 1 + s = sign + intpart + if digs > 0: s = s + '.' + fraction + e = repr(abs(expo)) + e = '0'*(3-len(e)) + e + if expo < 0: e = '-' + e + else: e = '+' + e + return s + 'e' + e + +def test(): + """Interactive test run.""" + try: + while 1: + x, digs = input('Enter (x, digs): ') + print x, fix(x, digs), sci(x, digs) + except (EOFError, KeyboardInterrupt): + pass diff --git a/third_party/stdlib/genericpath.py b/third_party/stdlib/genericpath.py new file mode 100644 index 00000000..2648e545 --- /dev/null +++ b/third_party/stdlib/genericpath.py @@ -0,0 +1,113 @@ +""" +Path operations common to more than one OS +Do not use directly. The OS specific modules import the appropriate +functions from this module themselves. +""" +import os +import stat + +__all__ = ['commonprefix', 'exists', 'getatime', 'getctime', 'getmtime', + 'getsize', 'isdir', 'isfile'] + + +try: + _unicode = unicode +except NameError: + # If Python is built without Unicode support, the unicode type + # will not exist. Fake one. + class _unicode(object): + pass + +# Does a path exist? +# This is false for dangling symbolic links on systems that support them. +def exists(path): + """Test whether a path exists. Returns False for broken symbolic links""" + try: + os.stat(path) + except os.error: + return False + return True + + +# This follows symbolic links, so both islink() and isdir() can be true +# for the same path on systems that support symlinks +def isfile(path): + """Test whether a path is a regular file""" + try: + st = os.stat(path) + except os.error: + return False + return stat.S_ISREG(st.st_mode) + + +# Is a path a directory? +# This follows symbolic links, so both islink() and isdir() +# can be true for the same path on systems that support symlinks +def isdir(s): + """Return true if the pathname refers to an existing directory.""" + try: + st = os.stat(s) + except os.error: + return False + return stat.S_ISDIR(st.st_mode) + + +def getsize(filename): + """Return the size of a file, reported by os.stat().""" + return os.stat(filename).st_size + + +def getmtime(filename): + """Return the last modification time of a file, reported by os.stat().""" + return os.stat(filename).st_mtime + + +def getatime(filename): + """Return the last access time of a file, reported by os.stat().""" + return os.stat(filename).st_atime + + +def getctime(filename): + """Return the metadata change time of a file, reported by os.stat().""" + return os.stat(filename).st_ctime + + +# Return the longest prefix of all list elements. +def commonprefix(m): + "Given a list of pathnames, returns the longest common leading component" + if not m: return '' + s1 = min(m) + s2 = max(m) + for i, c in enumerate(s1): + if c != s2[i]: + return s1[:i] + return s1 + +# Split a path in root and extension. 
+# The extension is everything starting at the last dot in the last +# pathname component; the root is everything before that. +# It is always true that root + ext == p. + +# Generic implementation of splitext, to be parametrized with +# the separators +def _splitext(p, sep, altsep, extsep): + """Split the extension from a pathname. + + Extension is everything from the last dot to the end, ignoring + leading dots. Returns "(root, ext)"; ext may be empty.""" + + sepIndex = p.rfind(sep) + if altsep: + altsepIndex = p.rfind(altsep) + sepIndex = max(sepIndex, altsepIndex) + + dotIndex = p.rfind(extsep) + if dotIndex > sepIndex: + # skip all leading dots + filenameIndex = sepIndex + 1 + while filenameIndex < dotIndex: + if p[filenameIndex] != extsep: + return p[:dotIndex], p[dotIndex:] + filenameIndex += 1 + + return p, '' diff --git a/third_party/stdlib/mimetools.py b/third_party/stdlib/mimetools.py new file mode 100644 index 00000000..30e2ce9d --- /dev/null +++ b/third_party/stdlib/mimetools.py @@ -0,0 +1,250 @@ +"""Various tools used by MIME-reading or MIME-writing programs.""" + + +import os +import sys +import tempfile +from warnings import filterwarnings, catch_warnings +with catch_warnings(): + if sys.py3kwarning: + filterwarnings("ignore", ".*rfc822 has been removed", DeprecationWarning) + import rfc822 + +from warnings import warnpy3k +warnpy3k("in 3.x, mimetools has been removed in favor of the email package", + stacklevel=2) + +__all__ = ["Message","choose_boundary","encode","decode","copyliteral", + "copybinary"] + +class Message(rfc822.Message): + """A derived class of rfc822.Message that knows about MIME headers and + contains some hooks for decoding encoded and multipart messages.""" + + def __init__(self, fp, seekable = 1): + rfc822.Message.__init__(self, fp, seekable) + self.encodingheader = \ + self.getheader('content-transfer-encoding') + self.typeheader = \ + self.getheader('content-type') + self.parsetype() + self.parseplist() + + def parsetype(self): + str = self.typeheader + if str is None: + str = 'text/plain' + if ';' in str: + i = str.index(';') + self.plisttext = str[i:] + str = str[:i] + else: + self.plisttext = '' + fields = str.split('/') + for i in range(len(fields)): + fields[i] = fields[i].strip().lower() + self.type = '/'.join(fields) + self.maintype = fields[0] + self.subtype = '/'.join(fields[1:]) + + def parseplist(self): + str = self.plisttext + self.plist = [] + while str[:1] == ';': + str = str[1:] + if ';' in str: + # XXX Should parse quotes! 
+ end = str.index(';') + else: + end = len(str) + f = str[:end] + if '=' in f: + i = f.index('=') + f = f[:i].strip().lower() + \ + '=' + f[i+1:].strip() + self.plist.append(f.strip()) + str = str[end:] + + def getplist(self): + return self.plist + + def getparam(self, name): + name = name.lower() + '=' + n = len(name) + for p in self.plist: + if p[:n] == name: + return rfc822.unquote(p[n:]) + return None + + def getparamnames(self): + result = [] + for p in self.plist: + i = p.find('=') + if i >= 0: + result.append(p[:i].lower()) + return result + + def getencoding(self): + if self.encodingheader is None: + return '7bit' + return self.encodingheader.lower() + + def gettype(self): + return self.type + + def getmaintype(self): + return self.maintype + + def getsubtype(self): + return self.subtype + + + + +# Utility functions +# ----------------- + +#try: +import thread +#except ImportError: +# import dummy_thread as thread +_counter_lock = thread.allocate_lock() +del thread + +_counter = 0 +def _get_next_counter(): + global _counter + _counter_lock.acquire() + _counter += 1 + result = _counter + _counter_lock.release() + return result + +_prefix = None + +#def choose_boundary(): +# """Return a string usable as a multipart boundary. +# +# The string chosen is unique within a single program run, and +# incorporates the user id (if available), process id (if available), +# and current time. So it's very unlikely the returned string appears +# in message text, but there's no guarantee. +# +# The boundary contains dots so you have to quote it in the header.""" +# +# global _prefix +# import time +# if _prefix is None: +# import socket +# try: +# hostid = socket.gethostbyname(socket.gethostname()) +# except socket.gaierror: +# hostid = '127.0.0.1' +# try: +# uid = repr(os.getuid()) +# except AttributeError: +# uid = '1' +# try: +# pid = repr(os.getpid()) +# except AttributeError: +# pid = '1' +# _prefix = hostid + '.' + uid + '.' 
+ pid +# return "%s.%.3f.%d" % (_prefix, time.time(), _get_next_counter()) + + +# Subroutines for decoding some common content-transfer-types + +def decode(input, output, encoding): + """Decode common content-transfer-encodings (base64, quopri, uuencode).""" + if encoding == 'base64': + import base64 + return base64.decode(input, output) + if encoding == 'quoted-printable': + import quopri + return quopri.decode(input, output) + if encoding in ('uuencode', 'x-uuencode', 'uue', 'x-uue'): + import uu + return uu.decode(input, output) + if encoding in ('7bit', '8bit'): + return output.write(input.read()) + if encoding in decodetab: + pipethrough(input, decodetab[encoding], output) + else: + raise ValueError, \ + 'unknown Content-Transfer-Encoding: %s' % encoding + +def encode(input, output, encoding): + """Encode common content-transfer-encodings (base64, quopri, uuencode).""" + if encoding == 'base64': + import base64 + return base64.encode(input, output) + if encoding == 'quoted-printable': + import quopri + return quopri.encode(input, output, 0) + if encoding in ('uuencode', 'x-uuencode', 'uue', 'x-uue'): + import uu + return uu.encode(input, output) + if encoding in ('7bit', '8bit'): + return output.write(input.read()) + if encoding in encodetab: + pipethrough(input, encodetab[encoding], output) + else: + raise ValueError, \ + 'unknown Content-Transfer-Encoding: %s' % encoding + +# The following is no longer used for standard encodings + +# XXX This requires that uudecode and mmencode are in $PATH + +uudecode_pipe = '''( +TEMP=/tmp/@uu.$$ +sed "s%^begin [0-7][0-7]* .*%begin 600 $TEMP%" | uudecode +cat $TEMP +rm $TEMP +)''' + +decodetab = { + 'uuencode': uudecode_pipe, + 'x-uuencode': uudecode_pipe, + 'uue': uudecode_pipe, + 'x-uue': uudecode_pipe, + 'quoted-printable': 'mmencode -u -q', + 'base64': 'mmencode -u -b', +} + +encodetab = { + 'x-uuencode': 'uuencode tempfile', + 'uuencode': 'uuencode tempfile', + 'x-uue': 'uuencode tempfile', + 'uue': 'uuencode tempfile', + 'quoted-printable': 'mmencode -q', + 'base64': 'mmencode -b', +} + +def pipeto(input, command): + pipe = os.popen(command, 'w') + copyliteral(input, pipe) + pipe.close() + +def pipethrough(input, command, output): + (fd, tempname) = tempfile.mkstemp() + temp = os.fdopen(fd, 'w') + copyliteral(input, temp) + temp.close() + pipe = os.popen(command + ' <' + tempname, 'r') + copybinary(pipe, output) + pipe.close() + os.unlink(tempname) + +def copyliteral(input, output): + while 1: + line = input.readline() + if not line: break + output.write(line) + +def copybinary(input, output): + BUFSIZE = 8192 + while 1: + line = input.read(BUFSIZE) + if not line: break + output.write(line) diff --git a/third_party/stdlib/mutex.py b/third_party/stdlib/mutex.py new file mode 100644 index 00000000..beb6e653 --- /dev/null +++ b/third_party/stdlib/mutex.py @@ -0,0 +1,55 @@ +"""Mutual exclusion -- for use with module sched + +A mutex has two pieces of state -- a 'locked' bit and a queue. +When the mutex is not locked, the queue is empty. +Otherwise, the queue contains 0 or more (function, argument) pairs +representing functions (or methods) waiting to acquire the lock. +When the mutex is unlocked while the queue is not empty, +the first queue entry is removed and its function(argument) pair called, +implying it now has the lock. + +Of course, no multi-threading is implied -- hence the funny interface +for lock, where a function is called once the lock is acquired. 
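The callback-style locking interface described in the docstring above looks like this in practice (sketch only; the mutex module is long deprecated in favour of real locks):

from mutex import mutex

m = mutex()

def report(who):
    print("%s has the lock" % who)

m.lock(report, "first")    # lock was free, so report("first") runs immediately
m.lock(report, "second")   # lock is still held, so this call is queued
m.unlock()                 # dequeues and runs report("second")
m.unlock()                 # queue now empty, the lock simply becomes free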
+""" +from warnings import warnpy3k +warnpy3k("the mutex module has been removed in Python 3.0", stacklevel=2) +del warnpy3k + +from collections import deque + +class mutex(object): + def __init__(self): + """Create a new mutex -- initially unlocked.""" + self.locked = False + self.queue = deque() + + def test(self): + """Test the locked bit of the mutex.""" + return self.locked + + def testandset(self): + """Atomic test-and-set -- grab the lock if it is not set, + return True if it succeeded.""" + if not self.locked: + self.locked = True + return True + else: + return False + + def lock(self, function, argument): + """Lock a mutex, call the function with supplied argument + when it is acquired. If the mutex is already locked, place + function and argument in the queue.""" + if self.testandset(): + function(argument) + else: + self.queue.append((function, argument)) + + def unlock(self): + """Unlock a mutex. If the queue is not empty, call the next + function with its argument.""" + if self.queue: + function, argument = self.queue.popleft() + function(argument) + else: + self.locked = False diff --git a/third_party/stdlib/quopri.py b/third_party/stdlib/quopri.py new file mode 100644 index 00000000..8788afc2 --- /dev/null +++ b/third_party/stdlib/quopri.py @@ -0,0 +1,237 @@ +#! /usr/bin/env python + +"""Conversions to/from quoted-printable transport encoding as per RFC 1521.""" + +# (Dec 1991 version). + +__all__ = ["encode", "decode", "encodestring", "decodestring"] + +ESCAPE = '=' +MAXLINESIZE = 76 +HEX = '0123456789ABCDEF' +EMPTYSTRING = '' + +try: + from binascii import a2b_qp, b2a_qp +except ImportError: + a2b_qp = None + b2a_qp = None + + +def needsquoting(c, quotetabs, header): + """Decide whether a particular character needs to be quoted. + + The 'quotetabs' flag indicates whether embedded tabs and spaces should be + quoted. Note that line-ending tabs and spaces are always encoded, as per + RFC 1521. + """ + if c in ' \t': + return quotetabs + # if header, we have to escape _ because _ is used to escape space + if c == '_': + return header + return c == ESCAPE or not (' ' <= c <= '~') + +def quote(c): + """Quote a single character.""" + i = ord(c) + return ESCAPE + HEX[i//16] + HEX[i%16] + + + +def encode(input, output, quotetabs, header = 0): + """Read 'input', apply quoted-printable encoding, and write to 'output'. + + 'input' and 'output' are files with readline() and write() methods. + The 'quotetabs' flag indicates whether embedded tabs and spaces should be + quoted. Note that line-ending tabs and spaces are always encoded, as per + RFC 1521. + The 'header' flag indicates whether we are encoding spaces as _ as per + RFC 1522. + """ + + if b2a_qp is not None: + data = input.read() + odata = b2a_qp(data, quotetabs = quotetabs, header = header) + output.write(odata) + return + + def write(s, output=output, lineEnd='\n'): + # RFC 1521 requires that the line ending in a space or tab must have + # that trailing character encoded. 
+ if s and s[-1:] in ' \t': + output.write(s[:-1] + quote(s[-1]) + lineEnd) + elif s == '.': + output.write(quote(s) + lineEnd) + else: + output.write(s + lineEnd) + + prevline = None + while 1: + line = input.readline() + if not line: + break + outline = [] + # Strip off any readline induced trailing newline + stripped = '' + if line[-1:] == '\n': + line = line[:-1] + stripped = '\n' + # Calculate the un-length-limited encoded line + for c in line: + if needsquoting(c, quotetabs, header): + c = quote(c) + if header and c == ' ': + outline.append('_') + else: + outline.append(c) + # First, write out the previous line + if prevline is not None: + write(prevline) + # Now see if we need any soft line breaks because of RFC-imposed + # length limitations. Then do the thisline->prevline dance. + thisline = EMPTYSTRING.join(outline) + while len(thisline) > MAXLINESIZE: + # Don't forget to include the soft line break `=' sign in the + # length calculation! + write(thisline[:MAXLINESIZE-1], lineEnd='=\n') + thisline = thisline[MAXLINESIZE-1:] + # Write out the current line + prevline = thisline + # Write out the last line, without a trailing newline + if prevline is not None: + write(prevline, lineEnd=stripped) + +def encodestring(s, quotetabs = 0, header = 0): + if b2a_qp is not None: + return b2a_qp(s, quotetabs = quotetabs, header = header) + from cStringIO import StringIO + infp = StringIO(s) + outfp = StringIO() + encode(infp, outfp, quotetabs, header) + return outfp.getvalue() + + + +def decode(input, output, header = 0): + """Read 'input', apply quoted-printable decoding, and write to 'output'. + 'input' and 'output' are files with readline() and write() methods. + If 'header' is true, decode underscore as space (per RFC 1522).""" + + if a2b_qp is not None: + data = input.read() + odata = a2b_qp(data, header = header) + output.write(odata) + return + + new = '' + while 1: + line = input.readline() + if not line: break + i, n = 0, len(line) + if n > 0 and line[n-1] == '\n': + partial = 0; n = n-1 + # Strip trailing whitespace + while n > 0 and line[n-1] in " \t\r": + n = n-1 + else: + partial = 1 + while i < n: + c = line[i] + if c == '_' and header: + new = new + ' '; i = i+1 + elif c != ESCAPE: + new = new + c; i = i+1 + elif i+1 == n and not partial: + partial = 1; break + elif i+1 < n and line[i+1] == ESCAPE: + new = new + ESCAPE; i = i+2 + elif i+2 < n and ishex(line[i+1]) and ishex(line[i+2]): + new = new + chr(unhex(line[i+1:i+3])); i = i+3 + else: # Bad escape sequence -- leave it in + new = new + c; i = i+1 + if not partial: + output.write(new + '\n') + new = '' + if new: + output.write(new) + +def decodestring(s, header = 0): + if a2b_qp is not None: + return a2b_qp(s, header = header) + from cStringIO import StringIO + infp = StringIO(s) + outfp = StringIO() + decode(infp, outfp, header = header) + return outfp.getvalue() + + + +# Other helper functions +def ishex(c): + """Return true if the character 'c' is a hexadecimal digit.""" + return '0' <= c <= '9' or 'a' <= c <= 'f' or 'A' <= c <= 'F' + +def unhex(s): + """Get the integer value of a hexadecimal number.""" + bits = 0 + for c in s: + if '0' <= c <= '9': + i = ord('0') + elif 'a' <= c <= 'f': + i = ord('a')-10 + elif 'A' <= c <= 'F': + i = ord('A')-10 + else: + break + bits = bits*16 + (ord(c) - i) + return bits + + + +def main(): + import sys + import getopt + try: + opts, args = getopt.getopt(sys.argv[1:], 'td') + except getopt.error, msg: + sys.stdout = sys.stderr + print msg + print "usage: quopri [-t | -d] [file] ..." 
+ print "-t: quote tabs" + print "-d: decode; default encode" + sys.exit(2) + deco = 0 + tabs = 0 + for o, a in opts: + if o == '-t': tabs = 1 + if o == '-d': deco = 1 + if tabs and deco: + sys.stdout = sys.stderr + print "-t and -d are mutually exclusive" + sys.exit(2) + if not args: args = ['-'] + sts = 0 + for file in args: + if file == '-': + fp = sys.stdin + else: + try: + fp = open(file) + except IOError, msg: + sys.stderr.write("%s: can't open (%s)\n" % (file, msg)) + sts = 1 + continue + if deco: + decode(fp, sys.stdout) + else: + encode(fp, sys.stdout, tabs) + if fp is not sys.stdin: + fp.close() + if sts: + sys.exit(sts) + + + +if __name__ == '__main__': + main() diff --git a/third_party/stdlib/rfc822.py b/third_party/stdlib/rfc822.py new file mode 100644 index 00000000..a69e40e6 --- /dev/null +++ b/third_party/stdlib/rfc822.py @@ -0,0 +1,1016 @@ +"""RFC 2822 message manipulation. + +Note: This is only a very rough sketch of a full RFC-822 parser; in particular +the tokenizing of addresses does not adhere to all the quoting rules. + +Note: RFC 2822 is a long awaited update to RFC 822. This module should +conform to RFC 2822, and is thus mis-named (it's not worth renaming it). Some +effort at RFC 2822 updates have been made, but a thorough audit has not been +performed. Consider any RFC 2822 non-conformance to be a bug. + + RFC 2822: http://www.faqs.org/rfcs/rfc2822.html + RFC 822 : http://www.faqs.org/rfcs/rfc822.html (obsolete) + +Directions for use: + +To create a Message object: first open a file, e.g.: + + fp = open(file, 'r') + +You can use any other legal way of getting an open file object, e.g. use +sys.stdin or call os.popen(). Then pass the open file object to the Message() +constructor: + + m = Message(fp) + +This class can work with any input object that supports a readline method. If +the input object has seek and tell capability, the rewindbody method will +work; also illegal lines will be pushed back onto the input stream. If the +input object lacks seek but has an `unread' method that can push back a line +of input, Message will use that to push back illegal lines. Thus this class +can be used to parse messages coming from a buffered stream. + +The optional `seekable' argument is provided as a workaround for certain stdio +libraries in which tell() discards buffered data before discovering that the +lseek() system call doesn't work. For maximum portability, you should set the +seekable argument to zero to prevent that initial \code{tell} when passing in +an unseekable object such as a file object created from a socket object. If +it is 1 on entry -- which it is by default -- the tell() method of the open +file object is called once; if this raises an exception, seekable is reset to +0. For other nonzero values of seekable, this test is not made. + +To get the text of a particular header there are several methods: + + str = m.getheader(name) + str = m.getrawheader(name) + +where name is the name of the header, e.g. 'Subject'. The difference is that +getheader() strips the leading and trailing whitespace, while getrawheader() +doesn't. Both functions retain embedded whitespace (including newlines) +exactly as they are specified in the header, and leave the case of the text +unchanged. + +For addresses and address lists there are functions + + realname, mailaddress = m.getaddr(name) + list = m.getaddrlist(name) + +where the latter returns a list of (realname, mailaddr) tuples. 
+ +There is also a method + + time = m.getdate(name) + +which parses a Date-like field and returns a time-compatible tuple, +i.e. a tuple such as returned by time.localtime() or accepted by +time.mktime(). + +See the class definition for lower level access methods. + +There are also some utility functions here. +""" +# Cleanup and extensions by Eric S. Raymond + +import time + +from warnings import warnpy3k +warnpy3k("in 3.x, rfc822 has been removed in favor of the email package", + stacklevel=2) + +__all__ = ["Message","AddressList","parsedate","parsedate_tz","mktime_tz"] + +_blanklines = ('\r\n', '\n') # Optimization for islast() + + +class Message(object): + """Represents a single RFC 2822-compliant message.""" + + def __init__(self, fp, seekable = 1): + """Initialize the class instance and read the headers.""" + if seekable == 1: + # Exercise tell() to make sure it works + # (and then assume seek() works, too) + try: + fp.tell() + except (AttributeError, IOError): + seekable = 0 + self.fp = fp + self.seekable = seekable + self.startofheaders = None + self.startofbody = None + # + if self.seekable: + try: + self.startofheaders = self.fp.tell() + except IOError: + self.seekable = 0 + # + self.readheaders() + # + if self.seekable: + try: + self.startofbody = self.fp.tell() + except IOError: + self.seekable = 0 + + def rewindbody(self): + """Rewind the file to the start of the body (if seekable).""" + if not self.seekable: + raise IOError, "unseekable file" + self.fp.seek(self.startofbody) + + def readheaders(self): + """Read header lines. + + Read header lines up to the entirely blank line that terminates them. + The (normally blank) line that ends the headers is skipped, but not + included in the returned list. If a non-header line ends the headers, + (which is an error), an attempt is made to backspace over it; it is + never included in the returned list. + + The variable self.status is set to the empty string if all went well, + otherwise it is an error message. The variable self.headers is a + completely uninterpreted list of lines contained in the header (so + printing them will reproduce the header exactly as it appears in the + file). + """ + self.dict = {} + self.unixfrom = '' + self.headers = lst = [] + self.status = '' + headerseen = "" + firstline = 1 + startofline = unread = tell = None + if hasattr(self.fp, 'unread'): + unread = self.fp.unread + elif self.seekable: + tell = self.fp.tell + while 1: + if tell: + try: + startofline = tell() + except IOError: + startofline = tell = None + self.seekable = 0 + line = self.fp.readline() + if not line: + self.status = 'EOF in headers' + break + # Skip unix From name time lines + if firstline and line.startswith('From '): + self.unixfrom = self.unixfrom + line + continue + firstline = 0 + if headerseen and line[0] in ' \t': + # It's a continuation line. + lst.append(line) + x = (self.dict[headerseen] + "\n " + line.strip()) + self.dict[headerseen] = x.strip() + continue + elif self.iscomment(line): + # It's a comment. Ignore it. + continue + elif self.islast(line): + # Note! No pushback here! The delimiter line gets eaten. + break + headerseen = self.isheader(line) + if headerseen: + # It's a legal header line, save it. + lst.append(line) + self.dict[headerseen] = line[len(headerseen)+1:].strip() + continue + elif headerseen is not None: + # An empty header name. These aren't allowed in HTTP, but it's + # probably a benign mistake. Don't add the header, just keep + # going. 
+ continue + else: + # It's not a header line; throw it back and stop here. + if not self.dict: + self.status = 'No headers' + else: + self.status = 'Non-header line where header expected' + # Try to undo the read. + if unread: + unread(line) + elif tell: + self.fp.seek(startofline) + else: + self.status = self.status + '; bad seek' + break + + def isheader(self, line): + """Determine whether a given line is a legal header. + + This method should return the header name, suitably canonicalized. + You may override this method in order to use Message parsing on tagged + data in RFC 2822-like formats with special header formats. + """ + i = line.find(':') + if i > -1: + return line[:i].lower() + return None + + def islast(self, line): + """Determine whether a line is a legal end of RFC 2822 headers. + + You may override this method if your application wants to bend the + rules, e.g. to strip trailing whitespace, or to recognize MH template + separators ('--------'). For convenience (e.g. for code reading from + sockets) a line consisting of \\r\\n also matches. + """ + return line in _blanklines + + def iscomment(self, line): + """Determine whether a line should be skipped entirely. + + You may override this method in order to use Message parsing on tagged + data in RFC 2822-like formats that support embedded comments or + free-text data. + """ + return False + + def getallmatchingheaders(self, name): + """Find all header lines matching a given header name. + + Look through the list of headers and find all lines matching a given + header name (and their continuation lines). A list of the lines is + returned, without interpretation. If the header does not occur, an + empty list is returned. If the header occurs multiple times, all + occurrences are returned. Case is not important in the header name. + """ + name = name.lower() + ':' + n = len(name) + lst = [] + hit = 0 + for line in self.headers: + if line[:n].lower() == name: + hit = 1 + elif not line[:1].isspace(): + hit = 0 + if hit: + lst.append(line) + return lst + + def getfirstmatchingheader(self, name): + """Get the first header line matching name. + + This is similar to getallmatchingheaders, but it returns only the + first matching header (and its continuation lines). + """ + name = name.lower() + ':' + n = len(name) + lst = [] + hit = 0 + for line in self.headers: + if hit: + if not line[:1].isspace(): + break + elif line[:n].lower() == name: + hit = 1 + if hit: + lst.append(line) + return lst + + def getrawheader(self, name): + """A higher-level interface to getfirstmatchingheader(). + + Return a string containing the literal text of the header but with the + keyword stripped. All leading, trailing and embedded whitespace is + kept in the string, however. Return None if the header does not + occur. + """ + + lst = self.getfirstmatchingheader(name) + if not lst: + return None + lst[0] = lst[0][len(name) + 1:] + return ''.join(lst) + + def getheader(self, name, default=None): + """Get the header value for a name. + + This is the normal interface: it returns a stripped version of the + header value for a given header name, or None if it doesn't exist. + This uses the dictionary version which finds the *last* such header. + """ + return self.dict.get(name.lower(), default) + get = getheader + + def getheaders(self, name): + """Get all values for a header. + + This returns a list of values for headers given more than once; each + value in the result list is stripped in the same way as the result of + getheader(). 
If the header is not given, return an empty list. + """ + result = [] + current = '' + have_header = 0 + for s in self.getallmatchingheaders(name): + if s[0].isspace(): + if current: + current = "%s\n %s" % (current, s.strip()) + else: + current = s.strip() + else: + if have_header: + result.append(current) + current = s[s.find(":") + 1:].strip() + have_header = 1 + if have_header: + result.append(current) + return result + + def getaddr(self, name): + """Get a single address from a header, as a tuple. + + An example return value: + ('Guido van Rossum', 'guido@cwi.nl') + """ + # New, by Ben Escoto + alist = self.getaddrlist(name) + if alist: + return alist[0] + else: + return (None, None) + + def getaddrlist(self, name): + """Get a list of addresses from a header. + + Retrieves a list of addresses from a header, where each address is a + tuple as returned by getaddr(). Scans all named headers, so it works + properly with multiple To: or Cc: headers for example. + """ + raw = [] + for h in self.getallmatchingheaders(name): + if h[0] in ' \t': + raw.append(h) + else: + if raw: + raw.append(', ') + i = h.find(':') + if i > 0: + addr = h[i+1:] + raw.append(addr) + alladdrs = ''.join(raw) + a = AddressList(alladdrs) + return a.addresslist + + def getdate(self, name): + """Retrieve a date field from a header. + + Retrieves a date field from the named header, returning a tuple + compatible with time.mktime(). + """ + try: + data = self[name] + except KeyError: + return None + return parsedate(data) + + def getdate_tz(self, name): + """Retrieve a date field from a header as a 10-tuple. + + The first 9 elements make up a tuple compatible with time.mktime(), + and the 10th is the offset of the poster's time zone from GMT/UTC. + """ + try: + data = self[name] + except KeyError: + return None + return parsedate_tz(data) + + + # Access as a dictionary (only finds *last* header of each type): + + def __len__(self): + """Get the number of headers in a message.""" + return len(self.dict) + + def __getitem__(self, name): + """Get a specific header, as from a dictionary.""" + return self.dict[name.lower()] + + def __setitem__(self, name, value): + """Set the value of a header. + + Note: This is not a perfect inversion of __getitem__, because any + changed headers get stuck at the end of the raw-headers list rather + than where the altered header was. 
+ """ + del self[name] # Won't fail if it doesn't exist + self.dict[name.lower()] = value + text = name + ": " + value + for line in text.split("\n"): + self.headers.append(line + "\n") + + def __delitem__(self, name): + """Delete all occurrences of a specific header, if it is present.""" + name = name.lower() + if not name in self.dict: + return + del self.dict[name] + name = name + ':' + n = len(name) + lst = [] + hit = 0 + for i in range(len(self.headers)): + line = self.headers[i] + if line[:n].lower() == name: + hit = 1 + elif not line[:1].isspace(): + hit = 0 + if hit: + lst.append(i) + for i in reversed(lst): + del self.headers[i] + + def setdefault(self, name, default=""): + lowername = name.lower() + if lowername in self.dict: + return self.dict[lowername] + else: + text = name + ": " + default + for line in text.split("\n"): + self.headers.append(line + "\n") + self.dict[lowername] = default + return default + + def has_key(self, name): + """Determine whether a message contains the named header.""" + return name.lower() in self.dict + + def __contains__(self, name): + """Determine whether a message contains the named header.""" + return name.lower() in self.dict + + def __iter__(self): + return iter(self.dict) + + def keys(self): + """Get all of a message's header field names.""" + return self.dict.keys() + + def values(self): + """Get all of a message's header field values.""" + return self.dict.values() + + def items(self): + """Get all of a message's headers. + + Returns a list of name, value tuples. + """ + return self.dict.items() + + def __str__(self): + return ''.join(self.headers) + + +# Utility functions +# ----------------- + +# XXX Should fix unquote() and quote() to be really conformant. +# XXX The inverses of the parse functions may also be useful. + + +def unquote(s): + """Remove quotes from a string.""" + if len(s) > 1: + if s.startswith('"') and s.endswith('"'): + return s[1:-1].replace('\\\\', '\\').replace('\\"', '"') + if s.startswith('<') and s.endswith('>'): + return s[1:-1] + return s + + +def quote(s): + """Add quotes around a string.""" + return s.replace('\\', '\\\\').replace('"', '\\"') + + +def parseaddr(address): + """Parse an address into a (realname, mailaddr) tuple.""" + a = AddressList(address) + lst = a.addresslist + if not lst: + return (None, None) + return lst[0] + + +class AddrlistClass(object): + """Address parser class by Ben Escoto. + + To understand what this class does, it helps to have a copy of + RFC 2822 in front of you. + + http://www.faqs.org/rfcs/rfc2822.html + + Note: this class interface is deprecated and may be removed in the future. + Use rfc822.AddressList instead. + """ + + def __init__(self, field): + """Initialize a new instance. + + `field' is an unparsed address header field, containing one or more + addresses. + """ + self.specials = '()<>@,:;.\"[]' + self.pos = 0 + self.LWS = ' \t' + self.CR = '\r\n' + self.atomends = self.specials + self.LWS + self.CR + # Note that RFC 2822 now specifies `.' as obs-phrase, meaning that it + # is obsolete syntax. RFC 2822 requires that we recognize obsolete + # syntax, so allow dots in phrases. 
+ self.phraseends = self.atomends.replace('.', '') + self.field = field + self.commentlist = [] + + def gotonext(self): + """Parse up to the start of the next address.""" + while self.pos < len(self.field): + if self.field[self.pos] in self.LWS + '\n\r': + self.pos = self.pos + 1 + elif self.field[self.pos] == '(': + self.commentlist.append(self.getcomment()) + else: break + + def getaddrlist(self): + """Parse all addresses. + + Returns a list containing all of the addresses. + """ + result = [] + ad = self.getaddress() + while ad: + result += ad + ad = self.getaddress() + return result + + def getaddress(self): + """Parse the next address.""" + self.commentlist = [] + self.gotonext() + + oldpos = self.pos + oldcl = self.commentlist + plist = self.getphraselist() + + self.gotonext() + returnlist = [] + + if self.pos >= len(self.field): + # Bad email address technically, no domain. + if plist: + returnlist = [(' '.join(self.commentlist), plist[0])] + + elif self.field[self.pos] in '.@': + # email address is just an addrspec + # this isn't very efficient since we start over + self.pos = oldpos + self.commentlist = oldcl + addrspec = self.getaddrspec() + returnlist = [(' '.join(self.commentlist), addrspec)] + + elif self.field[self.pos] == ':': + # address is a group + returnlist = [] + + fieldlen = len(self.field) + self.pos += 1 + while self.pos < len(self.field): + self.gotonext() + if self.pos < fieldlen and self.field[self.pos] == ';': + self.pos += 1 + break + returnlist = returnlist + self.getaddress() + + elif self.field[self.pos] == '<': + # Address is a phrase then a route addr + routeaddr = self.getrouteaddr() + + if self.commentlist: + returnlist = [(' '.join(plist) + ' (' + \ + ' '.join(self.commentlist) + ')', routeaddr)] + else: returnlist = [(' '.join(plist), routeaddr)] + + else: + if plist: + returnlist = [(' '.join(self.commentlist), plist[0])] + elif self.field[self.pos] in self.specials: + self.pos += 1 + + self.gotonext() + if self.pos < len(self.field) and self.field[self.pos] == ',': + self.pos += 1 + return returnlist + + def getrouteaddr(self): + """Parse a route address (Return-path value). + + This method just skips all the route stuff and returns the addrspec. 
+ """ + if self.field[self.pos] != '<': + return + + expectroute = 0 + self.pos += 1 + self.gotonext() + adlist = "" + while self.pos < len(self.field): + if expectroute: + self.getdomain() + expectroute = 0 + elif self.field[self.pos] == '>': + self.pos += 1 + break + elif self.field[self.pos] == '@': + self.pos += 1 + expectroute = 1 + elif self.field[self.pos] == ':': + self.pos += 1 + else: + adlist = self.getaddrspec() + self.pos += 1 + break + self.gotonext() + + return adlist + + def getaddrspec(self): + """Parse an RFC 2822 addr-spec.""" + aslist = [] + + self.gotonext() + while self.pos < len(self.field): + if self.field[self.pos] == '.': + aslist.append('.') + self.pos += 1 + elif self.field[self.pos] == '"': + aslist.append('"%s"' % self.getquote()) + elif self.field[self.pos] in self.atomends: + break + else: aslist.append(self.getatom()) + self.gotonext() + + if self.pos >= len(self.field) or self.field[self.pos] != '@': + return ''.join(aslist) + + aslist.append('@') + self.pos += 1 + self.gotonext() + return ''.join(aslist) + self.getdomain() + + def getdomain(self): + """Get the complete domain name from an address.""" + sdlist = [] + while self.pos < len(self.field): + if self.field[self.pos] in self.LWS: + self.pos += 1 + elif self.field[self.pos] == '(': + self.commentlist.append(self.getcomment()) + elif self.field[self.pos] == '[': + sdlist.append(self.getdomainliteral()) + elif self.field[self.pos] == '.': + self.pos += 1 + sdlist.append('.') + elif self.field[self.pos] in self.atomends: + break + else: sdlist.append(self.getatom()) + return ''.join(sdlist) + + def getdelimited(self, beginchar, endchars, allowcomments = 1): + """Parse a header fragment delimited by special characters. + + `beginchar' is the start character for the fragment. If self is not + looking at an instance of `beginchar' then getdelimited returns the + empty string. + + `endchars' is a sequence of allowable end-delimiting characters. + Parsing stops when one of these is encountered. + + If `allowcomments' is non-zero, embedded RFC 2822 comments are allowed + within the parsed fragment. + """ + if self.field[self.pos] != beginchar: + return '' + + slist = [''] + quote = 0 + self.pos += 1 + while self.pos < len(self.field): + if quote == 1: + slist.append(self.field[self.pos]) + quote = 0 + elif self.field[self.pos] in endchars: + self.pos += 1 + break + elif allowcomments and self.field[self.pos] == '(': + slist.append(self.getcomment()) + continue # have already advanced pos from getcomment + elif self.field[self.pos] == '\\': + quote = 1 + else: + slist.append(self.field[self.pos]) + self.pos += 1 + + return ''.join(slist) + + def getquote(self): + """Get a quote-delimited fragment from self's field.""" + return self.getdelimited('"', '"\r', 0) + + def getcomment(self): + """Get a parenthesis-delimited fragment from self's field.""" + return self.getdelimited('(', ')\r', 1) + + def getdomainliteral(self): + """Parse an RFC 2822 domain-literal.""" + return '[%s]' % self.getdelimited('[', ']\r', 0) + + def getatom(self, atomends=None): + """Parse an RFC 2822 atom. + + Optional atomends specifies a different set of end token delimiters + (the default is to use self.atomends). This is used e.g. in + getphraselist() since phrase endings must not include the `.' 
(which + is legal in phrases).""" + atomlist = [''] + if atomends is None: + atomends = self.atomends + + while self.pos < len(self.field): + if self.field[self.pos] in atomends: + break + else: atomlist.append(self.field[self.pos]) + self.pos += 1 + + return ''.join(atomlist) + + def getphraselist(self): + """Parse a sequence of RFC 2822 phrases. + + A phrase is a sequence of words, which are in turn either RFC 2822 + atoms or quoted-strings. Phrases are canonicalized by squeezing all + runs of continuous whitespace into one space. + """ + plist = [] + + while self.pos < len(self.field): + if self.field[self.pos] in self.LWS: + self.pos += 1 + elif self.field[self.pos] == '"': + plist.append(self.getquote()) + elif self.field[self.pos] == '(': + self.commentlist.append(self.getcomment()) + elif self.field[self.pos] in self.phraseends: + break + else: + plist.append(self.getatom(self.phraseends)) + + return plist + +class AddressList(AddrlistClass): + """An AddressList encapsulates a list of parsed RFC 2822 addresses.""" + def __init__(self, field): + AddrlistClass.__init__(self, field) + if field: + self.addresslist = self.getaddrlist() + else: + self.addresslist = [] + + def __len__(self): + return len(self.addresslist) + + def __str__(self): + return ", ".join(map(dump_address_pair, self.addresslist)) + + def __add__(self, other): + # Set union + newaddr = AddressList(None) + newaddr.addresslist = self.addresslist[:] + for x in other.addresslist: + if not x in self.addresslist: + newaddr.addresslist.append(x) + return newaddr + + def __iadd__(self, other): + # Set union, in-place + for x in other.addresslist: + if not x in self.addresslist: + self.addresslist.append(x) + return self + + def __sub__(self, other): + # Set difference + newaddr = AddressList(None) + for x in self.addresslist: + if not x in other.addresslist: + newaddr.addresslist.append(x) + return newaddr + + def __isub__(self, other): + # Set difference, in-place + for x in other.addresslist: + if x in self.addresslist: + self.addresslist.remove(x) + return self + + def __getitem__(self, index): + # Make indexing, slices, and 'in' work + return self.addresslist[index] + +def dump_address_pair(pair): + """Dump a (name, address) pair in a canonicalized form.""" + if pair[0]: + return '"' + pair[0] + '" <' + pair[1] + '>' + else: + return pair[1] + +# Parse a date field + +_monthnames = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', + 'aug', 'sep', 'oct', 'nov', 'dec', + 'january', 'february', 'march', 'april', 'may', 'june', 'july', + 'august', 'september', 'october', 'november', 'december'] +_daynames = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'] + +# The timezone table does not include the military time zones defined +# in RFC822, other than Z. According to RFC1123, the description in +# RFC822 gets the signs wrong, so we can't rely on any such time +# zones. RFC1123 recommends that numeric timezone indicators be used +# instead of timezone names. + +_timezones = {'UT':0, 'UTC':0, 'GMT':0, 'Z':0, + 'AST': -400, 'ADT': -300, # Atlantic (used in Canada) + 'EST': -500, 'EDT': -400, # Eastern + 'CST': -600, 'CDT': -500, # Central + 'MST': -700, 'MDT': -600, # Mountain + 'PST': -800, 'PDT': -700 # Pacific + } + + +def parsedate_tz(data): + """Convert a date string to a time tuple. + + Accounts for military timezones. + """ + if not data: + return None + data = data.split() + if data[0][-1] in (',', '.') or data[0].lower() in _daynames: + # There's a dayname here. 
Skip it + del data[0] + else: + # no space after the "weekday,"? + i = data[0].rfind(',') + if i >= 0: + data[0] = data[0][i+1:] + if len(data) == 3: # RFC 850 date, deprecated + stuff = data[0].split('-') + if len(stuff) == 3: + data = stuff + data[1:] + if len(data) == 4: + s = data[3] + i = s.find('+') + if i > 0: + data[3:] = [s[:i], s[i+1:]] + else: + data.append('') # Dummy tz + if len(data) < 5: + return None + data = data[:5] + [dd, mm, yy, tm, tz] = data + mm = mm.lower() + if not mm in _monthnames: + dd, mm = mm, dd.lower() + if not mm in _monthnames: + return None + mm = _monthnames.index(mm)+1 + if mm > 12: mm = mm - 12 + if dd[-1] == ',': + dd = dd[:-1] + i = yy.find(':') + if i > 0: + yy, tm = tm, yy + if yy[-1] == ',': + yy = yy[:-1] + if not yy[0].isdigit(): + yy, tz = tz, yy + if tm[-1] == ',': + tm = tm[:-1] + tm = tm.split(':') + if len(tm) == 2: + [thh, tmm] = tm + tss = '0' + elif len(tm) == 3: + [thh, tmm, tss] = tm + else: + return None + try: + yy = int(yy) + dd = int(dd) + thh = int(thh) + tmm = int(tmm) + tss = int(tss) + except ValueError: + return None + tzoffset = None + tz = tz.upper() + if tz in _timezones: + tzoffset = _timezones[tz] + else: + try: + tzoffset = int(tz) + except ValueError: + pass + # Convert a timezone offset into seconds ; -0500 -> -18000 + if tzoffset: + if tzoffset < 0: + tzsign = -1 + tzoffset = -tzoffset + else: + tzsign = 1 + tzoffset = tzsign * ( (tzoffset//100)*3600 + (tzoffset % 100)*60) + return (yy, mm, dd, thh, tmm, tss, 0, 1, 0, tzoffset) + + +def parsedate(data): + """Convert a time string to a time tuple.""" + t = parsedate_tz(data) + if t is None: + return t + return t[:9] + + +def mktime_tz(data): + """Turn a 10-tuple as returned by parsedate_tz() into a UTC timestamp.""" + if data[9] is None: + # No zone info, so localtime is better assumption than GMT + return time.mktime(data[:8] + (-1,)) + else: + t = time.mktime(data[:8] + (0,)) + return t - data[9] - time.timezone + +def formatdate(timeval=None): + """Returns time format preferred for Internet standards. + + Sun, 06 Nov 1994 08:49:37 GMT ; RFC 822, updated by RFC 1123 + + According to RFC 1123, day and month names must always be in + English. If not for that, this code could use strftime(). It + can't because strftime() honors the locale and could generate + non-English names. + """ + if timeval is None: + timeval = time.time() + timeval = time.gmtime(timeval) + return "%s, %02d %s %04d %02d:%02d:%02d GMT" % ( + ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")[timeval[6]], + timeval[2], + ("Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")[timeval[1]-1], + timeval[0], timeval[3], timeval[4], timeval[5]) + + +# When used as script, run a small test program. +# The first command line argument must be a filename containing one +# message in RFC-822 format. 
+ +if __name__ == '__main__': + import sys, os + file = os.path.join(os.environ['HOME'], 'Mail/inbox/1') + if sys.argv[1:]: file = sys.argv[1] + f = open(file, 'r') + m = Message(f) + print 'From:', m.getaddr('from') + print 'To:', m.getaddrlist('to') + print 'Subject:', m.getheader('subject') + print 'Date:', m.getheader('date') + date = m.getdate_tz('date') + tz = date[-1] + date = time.localtime(mktime_tz(date)) + if date: + print 'ParsedDate:', time.asctime(date), + hhmmss = tz + hhmm, ss = divmod(hhmmss, 60) + hh, mm = divmod(hhmm, 60) + print "%+03d%02d" % (hh, mm), + if ss: print ".%02d" % ss, + print + else: + print 'ParsedDate:', None + m.rewindbody() + n = 0 + while f.readline(): + n += 1 + print 'Lines:', n + print '-'*70 + print 'len =', len(m) + if 'Date' in m: print 'Date =', m['Date'] + if 'X-Nonsense' in m: pass + print 'keys =', m.keys() + print 'values =', m.values() + print 'items =', m.items() diff --git a/third_party/stdlib/sched.py b/third_party/stdlib/sched.py new file mode 100644 index 00000000..1c5d944f --- /dev/null +++ b/third_party/stdlib/sched.py @@ -0,0 +1,142 @@ +"""A generally useful event scheduler class. +Each instance of this class manages its own queue. +No multi-threading is implied; you are supposed to hack that +yourself, or use a single instance per application. +Each instance is parametrized with two functions, one that is +supposed to return the current time, one that is supposed to +implement a delay. You can implement real-time scheduling by +substituting time and sleep from built-in module time, or you can +implement simulated time by writing your own functions. This can +also be used to integrate scheduling with STDWIN events; the delay +function is allowed to modify the queue. Time can be expressed as +integers or floating point numbers, as long as it is consistent. +Events are specified by tuples (time, priority, action, argument). +As in UNIX, lower priority numbers mean higher priority; in this +way the queue can be maintained as a priority queue. Execution of the +event means calling the action function, passing it the argument +sequence in "argument" (remember that in Python, multiple function +arguments are be packed in a sequence). +The action function may be an instance method so it +has another way to reference private data (besides global variables). +""" + +# XXX The timefunc and delayfunc should have been defined as methods +# XXX so you can define new kinds of schedulers using subclassing +# XXX instead of having to define a module or class just to hold +# XXX the global state of your particular time and delay functions. 
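+
+# A minimal usage sketch (added for illustration; not part of the original
+# module): real-time scheduling built on the time module. `announce` stands
+# for any caller-supplied callback and is hypothetical.
+#
+#   import sched, time
+#   s = sched.scheduler(time.time, time.sleep)
+#   s.enter(1, 1, announce, ('one second later',))
+#   s.enterabs(time.time() + 2, 1, announce, ('two seconds later',))
+#   s.run()   # blocks, running announce(...) as each event comes due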
+ +import heapq +# TODO: grumpy modified version +#from collections import namedtuple + +__all__ = ["scheduler"] + +# TODO: Use namedtuple +# Event = namedtuple('Event', 'time, priority, action, argument') + +class Event(object): + + __slots__ = ['time', 'priority', 'action', 'argument'] + + def __init__(self, time, priority, action, argument): + self.time = time + self.priority = priority + self.action = action + self.argument = argument + + def get_fields(self): + return (self.time, self.priority, self.action, self.argument) + + def __eq__(s, o): return (s.time, s.priority) == (o.time, o.priority) + def __lt__(s, o): return (s.time, s.priority) < (o.time, o.priority) + def __le__(s, o): return (s.time, s.priority) <= (o.time, o.priority) + def __gt__(s, o): return (s.time, s.priority) > (o.time, o.priority) + def __ge__(s, o): return (s.time, s.priority) >= (o.time, o.priority) + +class scheduler(object): + def __init__(self, timefunc, delayfunc): + """Initialize a new instance, passing the time and delay + functions""" + self._queue = [] + self.timefunc = timefunc + self.delayfunc = delayfunc + + def enterabs(self, time, priority, action, argument): + """Enter a new event in the queue at an absolute time. + Returns an ID for the event which can be used to remove it, + if necessary. + """ + event = Event(time, priority, action, argument) + heapq.heappush(self._queue, event) + return event # The ID + + def enter(self, delay, priority, action, argument): + """A variant that specifies the time as a relative time. + This is actually the more commonly used interface. + """ + time = self.timefunc() + delay + return self.enterabs(time, priority, action, argument) + + def cancel(self, event): + """Remove an event from the queue. + This must be presented the ID as returned by enter(). + If the event is not in the queue, this raises ValueError. + """ + self._queue.remove(event) + heapq.heapify(self._queue) + + def empty(self): + """Check whether the queue is empty.""" + return not self._queue + + def run(self): + """Execute events until the queue is empty. + When there is a positive delay until the first event, the + delay function is called and the event is left in the queue; + otherwise, the event is removed from the queue and executed + (its action function is called, passing it the argument). If + the delay function returns prematurely, it is simply + restarted. + It is legal for both the delay function and the action + function to modify the queue or to raise an exception; + exceptions are not caught but the scheduler's state remains + well-defined so run() may be called again. + A questionable hack is added to allow other threads to run: + just after an event is executed, a delay of 0 is executed, to + avoid monopolizing the CPU when other threads are also + runnable. + """ + # localize variable access to minimize overhead + # and to improve thread safety + q = self._queue + delayfunc = self.delayfunc + timefunc = self.timefunc + pop = heapq.heappop + while q: + # TODO: modified part of grumpy version. + checked_event = q[0] + time, priority, action, argument = checked_event.get_fields() + now = timefunc() + if now < time: + delayfunc(time - now) + else: + event = pop(q) + # Verify that the event was not removed or altered + # by another thread after we last looked at q[0]. + if event is checked_event: + action(*argument) + delayfunc(0) # Let other threads run + else: + heapq.heappush(q, event) + + @property + def queue(self): + """An ordered list of upcoming events. 
+ Events are named tuples with fields for: + time, priority, action, arguments + """ + # Use heapq to sort the queue rather than using 'sorted(self._queue)'. + # With heapq, two events scheduled at the same time will show in + # the actual order they would be retrieved. + events = self._queue[:] + return map(heapq.heappop, [events]*len(events)) diff --git a/third_party/stdlib/test/list_tests.py b/third_party/stdlib/test/list_tests.py index 0786914f..b468b682 100644 --- a/third_party/stdlib/test/list_tests.py +++ b/third_party/stdlib/test/list_tests.py @@ -9,60 +9,60 @@ class CommonTest(seq_tests.CommonTest): - def test_init(self): - # Iterable arg is optional - self.assertEqual(self.type2test([]), self.type2test()) - - # Init clears previous values - a = self.type2test([1, 2, 3]) - a.__init__() - self.assertEqual(a, self.type2test([])) - - # Init overwrites previous values - a = self.type2test([1, 2, 3]) - a.__init__([4, 5, 6]) - self.assertEqual(a, self.type2test([4, 5, 6])) - - # Mutables always return a new object - b = self.type2test(a) - self.assertNotEqual(id(a), id(b)) - self.assertEqual(a, b) - - def test_repr(self): - l0 = [] - l2 = [0, 1, 2] - a0 = self.type2test(l0) - a2 = self.type2test(l2) - - self.assertEqual(str(a0), str(l0)) - self.assertEqual(repr(a0), repr(l0)) - self.assertEqual(repr(a2), repr(l2)) - self.assertEqual(str(a2), "[0, 1, 2]") - self.assertEqual(repr(a2), "[0, 1, 2]") - - a2.append(a2) - a2.append(3) - self.assertEqual(str(a2), "[0, 1, 2, [...], 3]") - self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]") - - l0 = [] - for i in xrange(sys.getrecursionlimit() + 100): - l0 = [l0] - self.assertRaises(RuntimeError, repr, l0) - - def test_print(self): - d = self.type2test(xrange(200)) - d.append(d) - d.extend(xrange(200,400)) - d.append(d) - d.append(400) - try: - with open(test_support.TESTFN, "wb") as fo: - print >> fo, d, - with open(test_support.TESTFN, "rb") as fo: - self.assertEqual(fo.read(), repr(d)) - finally: - os.remove(test_support.TESTFN) +# def test_init(self): +# # Iterable arg is optional +# self.assertEqual(self.type2test([]), self.type2test()) +# +# # Init clears previous values +# a = self.type2test([1, 2, 3]) +# a.__init__() +# self.assertEqual(a, self.type2test([])) +# +# # Init overwrites previous values +# a = self.type2test([1, 2, 3]) +# a.__init__([4, 5, 6]) +# self.assertEqual(a, self.type2test([4, 5, 6])) +# +# # Mutables always return a new object +# b = self.type2test(a) +# self.assertNotEqual(id(a), id(b)) +# self.assertEqual(a, b) + +# def test_repr(self): +# l0 = [] +# l2 = [0, 1, 2] +# a0 = self.type2test(l0) +# a2 = self.type2test(l2) +# +# self.assertEqual(str(a0), str(l0)) +# self.assertEqual(repr(a0), repr(l0)) +# self.assertEqual(repr(a2), repr(l2)) +# self.assertEqual(str(a2), "[0, 1, 2]") +# self.assertEqual(repr(a2), "[0, 1, 2]") +# +# a2.append(a2) +# a2.append(3) +# self.assertEqual(str(a2), "[0, 1, 2, [...], 3]") +# self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]") +# +# l0 = [] +# for i in xrange(sys.getrecursionlimit() + 100): +# l0 = [l0] +# self.assertRaises(RuntimeError, repr, l0) + +# def test_print(self): +# d = self.type2test(xrange(200)) +# d.append(d) +# d.extend(xrange(200,400)) +# d.append(d) +# d.append(400) +# try: +# with open(test_support.TESTFN, "wb") as fo: +# print >> fo, d, +# with open(test_support.TESTFN, "rb") as fo: +# self.assertEqual(fo.read(), repr(d)) +# finally: +# os.remove(test_support.TESTFN) def test_set_subscript(self): a = self.type2test(range(20)) @@ -75,15 +75,15 @@ def 
test_set_subscript(self): 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])) - def test_reversed(self): - a = self.type2test(range(20)) - r = reversed(a) - self.assertEqual(list(r), self.type2test(range(19, -1, -1))) - self.assertRaises(StopIteration, r.next) - self.assertEqual(list(reversed(self.type2test())), - self.type2test()) - # Bug 3689: make sure list-reversed-iterator doesn't have __len__ - self.assertRaises(TypeError, len, reversed([1,2,3])) +# def test_reversed(self): +# a = self.type2test(range(20)) +# r = reversed(a) +# self.assertEqual(list(r), self.type2test(range(19, -1, -1))) +# self.assertRaises(StopIteration, r.next) +# self.assertEqual(list(reversed(self.type2test())), +# self.type2test()) +# # Bug 3689: make sure list-reversed-iterator doesn't have __len__ +# self.assertRaises(TypeError, len, reversed([1,2,3])) def test_setitem(self): a = self.type2test([0, 1]) @@ -140,53 +140,53 @@ def test_delitem(self): self.assertRaises(TypeError, a.__delitem__) - def test_setslice(self): - l = [0, 1] - a = self.type2test(l) - - for i in range(-3, 4): - a[:i] = l[:i] - self.assertEqual(a, l) - a2 = a[:] - a2[:i] = a[:i] - self.assertEqual(a2, a) - a[i:] = l[i:] - self.assertEqual(a, l) - a2 = a[:] - a2[i:] = a[i:] - self.assertEqual(a2, a) - for j in range(-3, 4): - a[i:j] = l[i:j] - self.assertEqual(a, l) - a2 = a[:] - a2[i:j] = a[i:j] - self.assertEqual(a2, a) - - aa2 = a2[:] - aa2[:0] = [-2, -1] - self.assertEqual(aa2, [-2, -1, 0, 1]) - aa2[0:] = [] - self.assertEqual(aa2, []) - - a = self.type2test([1, 2, 3, 4, 5]) - a[:-1] = a - self.assertEqual(a, self.type2test([1, 2, 3, 4, 5, 5])) - a = self.type2test([1, 2, 3, 4, 5]) - a[1:] = a - self.assertEqual(a, self.type2test([1, 1, 2, 3, 4, 5])) - a = self.type2test([1, 2, 3, 4, 5]) - a[1:-1] = a - self.assertEqual(a, self.type2test([1, 1, 2, 3, 4, 5, 5])) - - a = self.type2test([]) - a[:] = tuple(range(10)) - self.assertEqual(a, self.type2test(range(10))) - - self.assertRaises(TypeError, a.__setslice__, 0, 1, 5) - self.assertRaises(TypeError, a.__setitem__, slice(0, 1, 5)) - - self.assertRaises(TypeError, a.__setslice__) - self.assertRaises(TypeError, a.__setitem__) +# def test_setslice(self): +# l = [0, 1] +# a = self.type2test(l) +# +# for i in range(-3, 4): +# a[:i] = l[:i] +# self.assertEqual(a, l) +# a2 = a[:] +# a2[:i] = a[:i] +# self.assertEqual(a2, a) +# a[i:] = l[i:] +# self.assertEqual(a, l) +# a2 = a[:] +# a2[i:] = a[i:] +# self.assertEqual(a2, a) +# for j in range(-3, 4): +# a[i:j] = l[i:j] +# self.assertEqual(a, l) +# a2 = a[:] +# a2[i:j] = a[i:j] +# self.assertEqual(a2, a) +# +# aa2 = a2[:] +# aa2[:0] = [-2, -1] +# self.assertEqual(aa2, [-2, -1, 0, 1]) +# aa2[0:] = [] +# self.assertEqual(aa2, []) +# +# a = self.type2test([1, 2, 3, 4, 5]) +# a[:-1] = a +# self.assertEqual(a, self.type2test([1, 2, 3, 4, 5, 5])) +# a = self.type2test([1, 2, 3, 4, 5]) +# a[1:] = a +# self.assertEqual(a, self.type2test([1, 1, 2, 3, 4, 5])) +# a = self.type2test([1, 2, 3, 4, 5]) +# a[1:-1] = a +# self.assertEqual(a, self.type2test([1, 1, 2, 3, 4, 5, 5])) +# +# a = self.type2test([]) +# a[:] = tuple(range(10)) +# self.assertEqual(a, self.type2test(range(10))) +# +# self.assertRaises(TypeError, a.__setslice__, 0, 1, 5) +# self.assertRaises(TypeError, a.__setitem__, slice(0, 1, 5)) +# +# self.assertRaises(TypeError, a.__setslice__) +# self.assertRaises(TypeError, a.__setitem__) def test_delslice(self): a = self.type2test([0, 1]) @@ -238,26 +238,26 @@ def test_append(self): self.assertRaises(TypeError, a.append) - def test_extend(self): - a1 = 
self.type2test([0]) - a2 = self.type2test((0, 1)) - a = a1[:] - a.extend(a2) - self.assertEqual(a, a1 + a2) - - a.extend(self.type2test([])) - self.assertEqual(a, a1 + a2) - - a.extend(a) - self.assertEqual(a, self.type2test([0, 0, 1, 0, 0, 1])) - - a = self.type2test("spam") - a.extend("eggs") - self.assertEqual(a, list("spameggs")) - - self.assertRaises(TypeError, a.extend, None) - - self.assertRaises(TypeError, a.extend) +# def test_extend(self): +# a1 = self.type2test([0]) +# a2 = self.type2test((0, 1)) +# a = a1[:] +# a.extend(a2) +# self.assertEqual(a, a1 + a2) +# +# a.extend(self.type2test([])) +# self.assertEqual(a, a1 + a2) +# +# a.extend(a) +# self.assertEqual(a, self.type2test([0, 0, 1, 0, 0, 1])) +# +# a = self.type2test("spam") +# a.extend("eggs") +# self.assertEqual(a, list("spameggs")) +# +# self.assertRaises(TypeError, a.extend, None) +# +# self.assertRaises(TypeError, a.extend) def test_insert(self): a = self.type2test([0, 1, 2]) @@ -351,62 +351,62 @@ def __eq__(self, other): self.assertRaises(BadExc, a.count, BadCmp()) - def test_index(self): - u = self.type2test([0, 1]) - self.assertEqual(u.index(0), 0) - self.assertEqual(u.index(1), 1) - self.assertRaises(ValueError, u.index, 2) - - u = self.type2test([-2, -1, 0, 0, 1, 2]) - self.assertEqual(u.count(0), 2) - self.assertEqual(u.index(0), 2) - self.assertEqual(u.index(0, 2), 2) - self.assertEqual(u.index(-2, -10), 0) - self.assertEqual(u.index(0, 3), 3) - self.assertEqual(u.index(0, 3, 4), 3) - self.assertRaises(ValueError, u.index, 2, 0, -10) - - self.assertRaises(TypeError, u.index) - - class BadExc(Exception): - pass - - class BadCmp(object): - def __eq__(self, other): - if other == 2: - raise BadExc() - return False - - a = self.type2test([0, 1, 2, 3]) - self.assertRaises(BadExc, a.index, BadCmp()) - - a = self.type2test([-2, -1, 0, 0, 1, 2]) - self.assertEqual(a.index(0), 2) - self.assertEqual(a.index(0, 2), 2) - self.assertEqual(a.index(0, -4), 2) - self.assertEqual(a.index(-2, -10), 0) - self.assertEqual(a.index(0, 3), 3) - self.assertEqual(a.index(0, -3), 3) - self.assertEqual(a.index(0, 3, 4), 3) - self.assertEqual(a.index(0, -3, -2), 3) - self.assertEqual(a.index(0, -4*sys.maxint, 4*sys.maxint), 2) - self.assertRaises(ValueError, a.index, 0, 4*sys.maxint,-4*sys.maxint) - self.assertRaises(ValueError, a.index, 2, 0, -10) - a.remove(0) - self.assertRaises(ValueError, a.index, 2, 0, 4) - self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2])) - - # Test modifying the list during index's iteration - class EvilCmp(object): - def __init__(self, victim): - self.victim = victim - def __eq__(self, other): - del self.victim[:] - return False - a = self.type2test() - a[:] = [EvilCmp(a) for _ in xrange(100)] - # This used to seg fault before patch #1005778 - self.assertRaises(ValueError, a.index, None) +# def test_index(self): +# u = self.type2test([0, 1]) +# self.assertEqual(u.index(0), 0) +# self.assertEqual(u.index(1), 1) +# self.assertRaises(ValueError, u.index, 2) +# +# u = self.type2test([-2, -1, 0, 0, 1, 2]) +# self.assertEqual(u.count(0), 2) +# self.assertEqual(u.index(0), 2) +# self.assertEqual(u.index(0, 2), 2) +# self.assertEqual(u.index(-2, -10), 0) +# self.assertEqual(u.index(0, 3), 3) +# self.assertEqual(u.index(0, 3, 4), 3) +# self.assertRaises(ValueError, u.index, 2, 0, -10) +# +# self.assertRaises(TypeError, u.index) +# +# class BadExc(Exception): +# pass +# +# class BadCmp(object): +# def __eq__(self, other): +# if other == 2: +# raise BadExc() +# return False +# +# a = self.type2test([0, 1, 2, 3]) +# 
self.assertRaises(BadExc, a.index, BadCmp()) +# +# a = self.type2test([-2, -1, 0, 0, 1, 2]) +# self.assertEqual(a.index(0), 2) +# self.assertEqual(a.index(0, 2), 2) +# self.assertEqual(a.index(0, -4), 2) +# self.assertEqual(a.index(-2, -10), 0) +# self.assertEqual(a.index(0, 3), 3) +# self.assertEqual(a.index(0, -3), 3) +# self.assertEqual(a.index(0, 3, 4), 3) +# self.assertEqual(a.index(0, -3, -2), 3) +# self.assertEqual(a.index(0, -4*sys.maxint, 4*sys.maxint), 2) +# self.assertRaises(ValueError, a.index, 0, 4*sys.maxint,-4*sys.maxint) +# self.assertRaises(ValueError, a.index, 2, 0, -10) +# a.remove(0) +# self.assertRaises(ValueError, a.index, 2, 0, 4) +# self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2])) +# +# # Test modifying the list during index's iteration +# class EvilCmp(object): +# def __init__(self, victim): +# self.victim = victim +# def __eq__(self, other): +# del self.victim[:] +# return False +# a = self.type2test() +# a[:] = [EvilCmp(a) for _ in xrange(100)] +# # This used to seg fault before patch #1005778 +# self.assertRaises(ValueError, a.index, None) def test_reverse(self): u = self.type2test([-2, -1, 0, 1, 2]) @@ -418,10 +418,10 @@ def test_reverse(self): self.assertRaises(TypeError, u.reverse, 42) - def test_sort(self): - with test_support.check_py3k_warnings( - ("the cmp argument is not supported", DeprecationWarning)): - self._test_sort() +# def test_sort(self): +# with test_support.check_py3k_warnings( +# ("the cmp argument is not supported", DeprecationWarning)): +# self._test_sort() def _test_sort(self): u = self.type2test([1, 0]) @@ -485,46 +485,46 @@ def test_imul(self): s *= 10 self.assertEqual(id(s), oldid) - def test_extendedslicing(self): - # subscript - a = self.type2test([0,1,2,3,4]) - - # deletion - del a[::2] - self.assertEqual(a, self.type2test([1,3])) - a = self.type2test(range(5)) - del a[1::2] - self.assertEqual(a, self.type2test([0,2,4])) - a = self.type2test(range(5)) - del a[1::-2] - self.assertEqual(a, self.type2test([0,2,3,4])) - a = self.type2test(range(10)) - del a[::1000] - self.assertEqual(a, self.type2test([1, 2, 3, 4, 5, 6, 7, 8, 9])) - # assignment - a = self.type2test(range(10)) - a[::2] = [-1]*5 - self.assertEqual(a, self.type2test([-1, 1, -1, 3, -1, 5, -1, 7, -1, 9])) - a = self.type2test(range(10)) - a[::-4] = [10]*3 - self.assertEqual(a, self.type2test([0, 10, 2, 3, 4, 10, 6, 7, 8 ,10])) - a = self.type2test(range(4)) - a[::-1] = a - self.assertEqual(a, self.type2test([3, 2, 1, 0])) - a = self.type2test(range(10)) - b = a[:] - c = a[:] - a[2:3] = self.type2test(["two", "elements"]) - b[slice(2,3)] = self.type2test(["two", "elements"]) - c[2:3:] = self.type2test(["two", "elements"]) - self.assertEqual(a, b) - self.assertEqual(a, c) - a = self.type2test(range(10)) - a[::2] = tuple(range(5)) - self.assertEqual(a, self.type2test([0, 1, 1, 3, 2, 5, 3, 7, 4, 9])) - # test issue7788 - a = self.type2test(range(10)) - del a[9::1<<333] +# def test_extendedslicing(self): +# # subscript +# a = self.type2test([0,1,2,3,4]) +# +# # deletion +# del a[::2] +# self.assertEqual(a, self.type2test([1,3])) +# a = self.type2test(range(5)) +# del a[1::2] +# self.assertEqual(a, self.type2test([0,2,4])) +# a = self.type2test(range(5)) +# del a[1::-2] +# self.assertEqual(a, self.type2test([0,2,3,4])) +# a = self.type2test(range(10)) +# del a[::1000] +# self.assertEqual(a, self.type2test([1, 2, 3, 4, 5, 6, 7, 8, 9])) +# # assignment +# a = self.type2test(range(10)) +# a[::2] = [-1]*5 +# self.assertEqual(a, self.type2test([-1, 1, -1, 3, -1, 5, -1, 7, -1, 
9])) +# a = self.type2test(range(10)) +# a[::-4] = [10]*3 +# self.assertEqual(a, self.type2test([0, 10, 2, 3, 4, 10, 6, 7, 8 ,10])) +# a = self.type2test(range(4)) +# a[::-1] = a +# self.assertEqual(a, self.type2test([3, 2, 1, 0])) +# a = self.type2test(range(10)) +# b = a[:] +# c = a[:] +# a[2:3] = self.type2test(["two", "elements"]) +# b[slice(2,3)] = self.type2test(["two", "elements"]) +# c[2:3:] = self.type2test(["two", "elements"]) +# self.assertEqual(a, b) +# self.assertEqual(a, c) +# a = self.type2test(range(10)) +# a[::2] = tuple(range(5)) +# self.assertEqual(a, self.type2test([0, 1, 1, 3, 2, 5, 3, 7, 4, 9])) +# # test issue7788 +# a = self.type2test(range(10)) +# del a[9::1<<333] def test_constructor_exception_handling(self): # Bug #1242657 diff --git a/third_party/stdlib/test/lock_tests.py b/third_party/stdlib/test/lock_tests.py new file mode 100644 index 00000000..882112fd --- /dev/null +++ b/third_party/stdlib/test/lock_tests.py @@ -0,0 +1,583 @@ +""" +Various tests for synchronization primitives. +""" + +import sys +import time +from thread import start_new_thread, get_ident +import threading +import unittest + +from test import test_support as support + + +def _wait(): + # A crude wait/yield function not relying on synchronization primitives. + time.sleep(0.01) + +class Bunch(object): + """ + A bunch of threads. + """ + def __init__(self, f, n, wait_before_exit=False): + """ + Construct a bunch of `n` threads running the same function `f`. + If `wait_before_exit` is True, the threads won't terminate until + do_finish() is called. + """ + self.f = f + self.n = n + self.started = [] + self.finished = [] + self._can_exit = not wait_before_exit + def task(): + tid = get_ident() + self.started.append(tid) + try: + f() + finally: + self.finished.append(tid) + while not self._can_exit: + _wait() + try: + for i in range(n): + start_new_thread(task, ()) + except: + self._can_exit = True + raise + + def wait_for_started(self): + while len(self.started) < self.n: + _wait() + + def wait_for_finished(self): + while len(self.finished) < self.n: + _wait() + + def do_finish(self): + self._can_exit = True + + +class BaseTestCase(unittest.TestCase): + def setUp(self): + self._threads = support.threading_setup() + + def tearDown(self): + support.threading_cleanup(*self._threads) + support.reap_children() + + +class BaseLockTests(BaseTestCase): + """ + Tests for both recursive and non-recursive locks. 
+ """ + + def test_constructor(self): + lock = self.locktype() + del lock + + def test_acquire_destroy(self): + lock = self.locktype() + lock.acquire() + del lock + + def test_acquire_release(self): + lock = self.locktype() + lock.acquire() + lock.release() + del lock + + def test_try_acquire(self): + lock = self.locktype() + self.assertTrue(lock.acquire(False)) + lock.release() + + def test_try_acquire_contended(self): + lock = self.locktype() + lock.acquire() + result = [] + def f(): + result.append(lock.acquire(False)) + Bunch(f, 1).wait_for_finished() + self.assertFalse(result[0]) + lock.release() + + def test_acquire_contended(self): + lock = self.locktype() + lock.acquire() + N = 5 + def f(): + lock.acquire() + lock.release() + + b = Bunch(f, N) + b.wait_for_started() + _wait() + self.assertEqual(len(b.finished), 0) + lock.release() + b.wait_for_finished() + self.assertEqual(len(b.finished), N) + + def test_with(self): + lock = self.locktype() + def f(): + lock.acquire() + lock.release() + def _with(err=None): + with lock: + if err is not None: + raise err + _with() + # Check the lock is unacquired + Bunch(f, 1).wait_for_finished() + self.assertRaises(TypeError, _with, TypeError) + # Check the lock is unacquired + Bunch(f, 1).wait_for_finished() + + def test_thread_leak(self): + # The lock shouldn't leak a Thread instance when used from a foreign + # (non-threading) thread. + lock = self.locktype() + def f(): + lock.acquire() + lock.release() + n = len(threading.enumerate()) + # We run many threads in the hope that existing threads ids won't + # be recycled. + Bunch(f, 15).wait_for_finished() + self.assertEqual(n, len(threading.enumerate())) + + +class LockTests(BaseLockTests): + """ + Tests for non-recursive, weak locks + (which can be acquired and released from different threads). + """ + def test_reacquire(self): + # Lock needs to be released before re-acquiring. + lock = self.locktype() + phase = [] + def f(): + lock.acquire() + phase.append(None) + lock.acquire() + phase.append(None) + start_new_thread(f, ()) + while len(phase) == 0: + _wait() + _wait() + self.assertEqual(len(phase), 1) + lock.release() + while len(phase) == 1: + _wait() + self.assertEqual(len(phase), 2) + + def test_different_thread(self): + # Lock can be released from a different thread. + lock = self.locktype() + lock.acquire() + def f(): + lock.release() + b = Bunch(f, 1) + b.wait_for_finished() + lock.acquire() + lock.release() + + +class RLockTests(BaseLockTests): + """ + Tests for recursive locks. 
+ """ + def test_reacquire(self): + lock = self.locktype() + lock.acquire() + lock.acquire() + lock.release() + lock.acquire() + lock.release() + lock.release() + + def test_release_unacquired(self): + # Cannot release an unacquired lock + lock = self.locktype() + self.assertRaises(RuntimeError, lock.release) + lock.acquire() + lock.acquire() + lock.release() + lock.acquire() + lock.release() + lock.release() + self.assertRaises(RuntimeError, lock.release) + + def test_different_thread(self): + # Cannot release from a different thread + lock = self.locktype() + def f(): + lock.acquire() + b = Bunch(f, 1, True) + try: + self.assertRaises(RuntimeError, lock.release) + finally: + b.do_finish() + + def test__is_owned(self): + lock = self.locktype() + self.assertFalse(lock._is_owned()) + lock.acquire() + self.assertTrue(lock._is_owned()) + lock.acquire() + self.assertTrue(lock._is_owned()) + result = [] + def f(): + result.append(lock._is_owned()) + Bunch(f, 1).wait_for_finished() + self.assertFalse(result[0]) + lock.release() + self.assertTrue(lock._is_owned()) + lock.release() + self.assertFalse(lock._is_owned()) + + +class EventTests(BaseTestCase): + """ + Tests for Event objects. + """ + + def test_is_set(self): + evt = self.eventtype() + self.assertFalse(evt.is_set()) + evt.set() + self.assertTrue(evt.is_set()) + evt.set() + self.assertTrue(evt.is_set()) + evt.clear() + self.assertFalse(evt.is_set()) + evt.clear() + self.assertFalse(evt.is_set()) + + def _check_notify(self, evt): + # All threads get notified + N = 5 + results1 = [] + results2 = [] + def f(): + results1.append(evt.wait()) + results2.append(evt.wait()) + b = Bunch(f, N) + b.wait_for_started() + _wait() + self.assertEqual(len(results1), 0) + evt.set() + b.wait_for_finished() + self.assertEqual(results1, [True] * N) + self.assertEqual(results2, [True] * N) + + def test_notify(self): + evt = self.eventtype() + self._check_notify(evt) + # Another time, after an explicit clear() + evt.set() + evt.clear() + self._check_notify(evt) + + def test_timeout(self): + evt = self.eventtype() + results1 = [] + results2 = [] + N = 5 + def f(): + results1.append(evt.wait(0.0)) + t1 = time.time() + r = evt.wait(0.2) + t2 = time.time() + results2.append((r, t2 - t1)) + Bunch(f, N).wait_for_finished() + self.assertEqual(results1, [False] * N) + for r, dt in results2: + self.assertFalse(r) + self.assertTrue(dt >= 0.2, dt) + # The event is set + results1 = [] + results2 = [] + evt.set() + Bunch(f, N).wait_for_finished() + self.assertEqual(results1, [True] * N) + for r, dt in results2: + self.assertTrue(r) + + @unittest.skip('grumpy') + def test_reset_internal_locks(self): + evt = self.eventtype() + old_lock = evt._Event__cond._Condition__lock + evt._reset_internal_locks() + new_lock = evt._Event__cond._Condition__lock + self.assertIsNot(new_lock, old_lock) + self.assertIs(type(new_lock), type(old_lock)) + + +class ConditionTests(BaseTestCase): + """ + Tests for condition variables. + """ + + def test_acquire(self): + cond = self.condtype() + # Be default we have an RLock: the condition can be acquired multiple + # times. 
+ cond.acquire() + cond.acquire() + cond.release() + cond.release() + lock = threading.Lock() + cond = self.condtype(lock) + cond.acquire() + self.assertFalse(lock.acquire(False)) + cond.release() + self.assertTrue(lock.acquire(False)) + self.assertFalse(cond.acquire(False)) + lock.release() + with cond: + self.assertFalse(lock.acquire(False)) + + def test_unacquired_wait(self): + cond = self.condtype() + self.assertRaises(RuntimeError, cond.wait) + + def test_unacquired_notify(self): + cond = self.condtype() + self.assertRaises(RuntimeError, cond.notify) + + def _check_notify(self, cond): + # Note that this test is sensitive to timing. If the worker threads + # don't execute in a timely fashion, the main thread may think they + # are further along then they are. The main thread therefore issues + # _wait() statements to try to make sure that it doesn't race ahead + # of the workers. + # Secondly, this test assumes that condition variables are not subject + # to spurious wakeups. The absence of spurious wakeups is an implementation + # detail of Condition Cariables in current CPython, but in general, not + # a guaranteed property of condition variables as a programming + # construct. In particular, it is possible that this can no longer + # be conveniently guaranteed should their implementation ever change. + N = 5 + ready = [] + results1 = [] + results2 = [] + phase_num = 0 + def f(): + cond.acquire() + ready.append(phase_num) + cond.wait() + cond.release() + results1.append(phase_num) + cond.acquire() + ready.append(phase_num) + cond.wait() + cond.release() + results2.append(phase_num) + b = Bunch(f, N) + b.wait_for_started() + # first wait, to ensure all workers settle into cond.wait() before + # we continue. See issues #8799 and #30727. + while len(ready) < 5: + _wait() + ready = [] + self.assertEqual(results1, []) + # Notify 3 threads at first + cond.acquire() + cond.notify(3) + _wait() + phase_num = 1 + cond.release() + while len(results1) < 3: + _wait() + self.assertEqual(results1, [1] * 3) + self.assertEqual(results2, []) + # make sure all awaken workers settle into cond.wait() + while len(ready) < 3: + _wait() + # Notify 5 threads: they might be in their first or second wait + cond.acquire() + cond.notify(5) + _wait() + phase_num = 2 + cond.release() + while len(results1) + len(results2) < 8: + _wait() + self.assertEqual(results1, [1] * 3 + [2] * 2) + self.assertEqual(results2, [2] * 3) + # make sure all workers settle into cond.wait() + while len(ready) < 5: + _wait() + # Notify all threads: they are all in their second wait + cond.acquire() + cond.notify_all() + _wait() + phase_num = 3 + cond.release() + while len(results2) < 5: + _wait() + self.assertEqual(results1, [1] * 3 + [2] * 2) + self.assertEqual(results2, [2] * 3 + [3] * 2) + b.wait_for_finished() + + def test_notify(self): + cond = self.condtype() + self._check_notify(cond) + # A second time, to check internal state is still ok. + self._check_notify(cond) + + def test_timeout(self): + cond = self.condtype() + results = [] + N = 5 + def f(): + cond.acquire() + t1 = time.time() + cond.wait(0.2) + t2 = time.time() + cond.release() + results.append(t2 - t1) + Bunch(f, N).wait_for_finished() + self.assertEqual(len(results), 5) + for dt in results: + self.assertTrue(dt >= 0.2, dt) + + +class BaseSemaphoreTests(BaseTestCase): + """ + Common tests for {bounded, unbounded} semaphore objects. 
+ """ + + def test_constructor(self): + self.assertRaises(ValueError, self.semtype, value = -1) + self.assertRaises(ValueError, self.semtype, value = -sys.maxint) + + def test_acquire(self): + sem = self.semtype(1) + sem.acquire() + sem.release() + sem = self.semtype(2) + sem.acquire() + sem.acquire() + sem.release() + sem.release() + + def test_acquire_destroy(self): + sem = self.semtype() + sem.acquire() + del sem + + def test_acquire_contended(self): + sem = self.semtype(7) + sem.acquire() + N = 10 + results1 = [] + results2 = [] + phase_num = 0 + def f(): + sem.acquire() + results1.append(phase_num) + sem.acquire() + results2.append(phase_num) + b = Bunch(f, 10) + b.wait_for_started() + while len(results1) + len(results2) < 6: + _wait() + self.assertEqual(results1 + results2, [0] * 6) + phase_num = 1 + for i in range(7): + sem.release() + while len(results1) + len(results2) < 13: + _wait() + self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7) + phase_num = 2 + for i in range(6): + sem.release() + while len(results1) + len(results2) < 19: + _wait() + self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6) + # The semaphore is still locked + self.assertFalse(sem.acquire(False)) + # Final release, to let the last thread finish + sem.release() + b.wait_for_finished() + + def test_try_acquire(self): + sem = self.semtype(2) + self.assertTrue(sem.acquire(False)) + self.assertTrue(sem.acquire(False)) + self.assertFalse(sem.acquire(False)) + sem.release() + self.assertTrue(sem.acquire(False)) + + def test_try_acquire_contended(self): + sem = self.semtype(4) + sem.acquire() + results = [] + def f(): + results.append(sem.acquire(False)) + results.append(sem.acquire(False)) + Bunch(f, 5).wait_for_finished() + # There can be a thread switch between acquiring the semaphore and + # appending the result, therefore results will not necessarily be + # ordered. + self.assertEqual(sorted(results), [False] * 7 + [True] * 3 ) + + def test_default_value(self): + # The default initial value is 1. + sem = self.semtype() + sem.acquire() + def f(): + sem.acquire() + sem.release() + b = Bunch(f, 1) + b.wait_for_started() + _wait() + self.assertFalse(b.finished) + sem.release() + b.wait_for_finished() + + def test_with(self): + sem = self.semtype(2) + def _with(err=None): + with sem: + self.assertTrue(sem.acquire(False)) + sem.release() + with sem: + self.assertFalse(sem.acquire(False)) + if err: + raise err + _with() + self.assertTrue(sem.acquire(False)) + sem.release() + self.assertRaises(TypeError, _with, TypeError) + self.assertTrue(sem.acquire(False)) + sem.release() + +class SemaphoreTests(BaseSemaphoreTests): + """ + Tests for unbounded semaphores. + """ + + def test_release_unacquired(self): + # Unbounded releases are allowed and increment the semaphore's value + sem = self.semtype(1) + sem.release() + sem.acquire() + sem.acquire() + sem.release() + + +class BoundedSemaphoreTests(BaseSemaphoreTests): + """ + Tests for bounded semaphores. 
+ """ + + def test_release_unacquired(self): + # Cannot go past the initial value + sem = self.semtype() + self.assertRaises(ValueError, sem.release) + sem.acquire() + sem.release() + self.assertRaises(ValueError, sem.release) diff --git a/third_party/stdlib/test/mapping_tests.py b/third_party/stdlib/test/mapping_tests.py index 27bc54c9..1c8cfd90 100644 --- a/third_party/stdlib/test/mapping_tests.py +++ b/third_party/stdlib/test/mapping_tests.py @@ -40,6 +40,7 @@ def __init__(self, *args, **kw): self.inmapping = {key:value} self.reference[key] = value + @unittest.skip('grumpy') def test_read(self): # Test for read only operations on mapping p = self._empty_mapping() @@ -172,6 +173,7 @@ def test_getitem(self): self.assertRaises(TypeError, d.__getitem__) + @unittest.skip('grumpy') def test_update(self): # mapping argument d = self._empty_mapping() diff --git a/third_party/stdlib/test/seq_tests.py b/third_party/stdlib/test/seq_tests.py index aaa2e69e..6a65f4c3 100644 --- a/third_party/stdlib/test/seq_tests.py +++ b/third_party/stdlib/test/seq_tests.py @@ -174,6 +174,7 @@ def test_getitem(self): self.assertRaises(IndexError, a.__getitem__, -3) self.assertRaises(IndexError, a.__getitem__, 3) + @unittest.skip('grumpy') def test_getslice(self): l = [0, 1, 2, 3, 4] u = self.type2test(l) @@ -258,6 +259,7 @@ def test_minmax(self): self.assertEqual(min(u), 0) self.assertEqual(max(u), 2) + @unittest.skip('grumpy') def test_addmul(self): u1 = self.type2test([0]) u2 = self.type2test([0, 1]) @@ -311,6 +313,7 @@ def __getitem__(self, key): return str(key) + '!!!' self.assertEqual(iter(T((1,2))).next(), 1) + @unittest.skip('grumpy') def test_repeat(self): for m in xrange(4): s = tuple(range(m)) @@ -367,6 +370,7 @@ def __eq__(self, other): self.assertRaises(BadExc, a.count, BadCmp()) + @unittest.skip('grumpy') def test_index(self): u = self.type2test([0, 1]) self.assertEqual(u.index(0), 0) @@ -409,6 +413,7 @@ def __eq__(self, other): self.assertRaises(ValueError, a.index, 0, 4*sys.maxint,-4*sys.maxint) self.assertRaises(ValueError, a.index, 2, 0, -10) + @unittest.skip('grumpy') def test_free_after_iterating(self): support.check_free_after_iterating(self, iter, self.type2test) support.check_free_after_iterating(self, reversed, self.type2test) diff --git a/third_party/stdlib/test/string_tests.py b/third_party/stdlib/test/string_tests.py index 7ca6d4a2..d7941271 100644 --- a/third_party/stdlib/test/string_tests.py +++ b/third_party/stdlib/test/string_tests.py @@ -94,446 +94,446 @@ def test_hash(self): hash(b) self.assertEqual(hash(a), hash(b)) - def test_capitalize(self): - self.checkequal(' hello ', ' hello ', 'capitalize') - self.checkequal('Hello ', 'Hello ','capitalize') - self.checkequal('Hello ', 'hello ','capitalize') - self.checkequal('Aaaa', 'aaaa', 'capitalize') - self.checkequal('Aaaa', 'AaAa', 'capitalize') - - self.checkraises(TypeError, 'hello', 'capitalize', 42) - - def test_count(self): - self.checkequal(3, 'aaa', 'count', 'a') - self.checkequal(0, 'aaa', 'count', 'b') - self.checkequal(3, 'aaa', 'count', 'a') - self.checkequal(0, 'aaa', 'count', 'b') - self.checkequal(3, 'aaa', 'count', 'a') - self.checkequal(0, 'aaa', 'count', 'b') - self.checkequal(0, 'aaa', 'count', 'b') - self.checkequal(2, 'aaa', 'count', 'a', 1) - self.checkequal(0, 'aaa', 'count', 'a', 10) - self.checkequal(1, 'aaa', 'count', 'a', -1) - self.checkequal(3, 'aaa', 'count', 'a', -10) - self.checkequal(1, 'aaa', 'count', 'a', 0, 1) - self.checkequal(3, 'aaa', 'count', 'a', 0, 10) - self.checkequal(2, 'aaa', 'count', 'a', 0, 
-1) - self.checkequal(0, 'aaa', 'count', 'a', 0, -10) - self.checkequal(3, 'aaa', 'count', '', 1) - self.checkequal(1, 'aaa', 'count', '', 3) - self.checkequal(0, 'aaa', 'count', '', 10) - self.checkequal(2, 'aaa', 'count', '', -1) - self.checkequal(4, 'aaa', 'count', '', -10) - - self.checkequal(1, '', 'count', '') - self.checkequal(0, '', 'count', '', 1, 1) - self.checkequal(0, '', 'count', '', sys.maxint, 0) - - self.checkequal(0, '', 'count', 'xx') - self.checkequal(0, '', 'count', 'xx', 1, 1) - self.checkequal(0, '', 'count', 'xx', sys.maxint, 0) - - self.checkraises(TypeError, 'hello', 'count') - self.checkraises(TypeError, 'hello', 'count', 42) - - # For a variety of combinations, - # verify that str.count() matches an equivalent function - # replacing all occurrences and then differencing the string lengths - charset = ['', 'a', 'b'] - digits = 7 - base = len(charset) - teststrings = set() - for i in xrange(base ** digits): - entry = [] - for j in xrange(digits): - i, m = divmod(i, base) - entry.append(charset[m]) - teststrings.add(''.join(entry)) - teststrings = list(teststrings) - for i in teststrings: - i = self.fixtype(i) - n = len(i) - for j in teststrings: - r1 = i.count(j) - if j: - r2, rem = divmod(n - len(i.replace(j, '')), len(j)) - else: - r2, rem = len(i)+1, 0 - if rem or r1 != r2: - self.assertEqual(rem, 0, '%s != 0 for %s' % (rem, i)) - self.assertEqual(r1, r2, '%s != %s for %s' % (r1, r2, i)) - - def test_find(self): - self.checkequal(0, 'abcdefghiabc', 'find', 'abc') - self.checkequal(9, 'abcdefghiabc', 'find', 'abc', 1) - self.checkequal(-1, 'abcdefghiabc', 'find', 'def', 4) - - self.checkequal(0, 'abc', 'find', '', 0) - self.checkequal(3, 'abc', 'find', '', 3) - self.checkequal(-1, 'abc', 'find', '', 4) - - # to check the ability to pass None as defaults - self.checkequal( 2, 'rrarrrrrrrrra', 'find', 'a') - self.checkequal(12, 'rrarrrrrrrrra', 'find', 'a', 4) - self.checkequal(-1, 'rrarrrrrrrrra', 'find', 'a', 4, 6) - self.checkequal(12, 'rrarrrrrrrrra', 'find', 'a', 4, None) - self.checkequal( 2, 'rrarrrrrrrrra', 'find', 'a', None, 6) - - self.checkraises(TypeError, 'hello', 'find') - self.checkraises(TypeError, 'hello', 'find', 42) - - self.checkequal(0, '', 'find', '') - self.checkequal(-1, '', 'find', '', 1, 1) - self.checkequal(-1, '', 'find', '', sys.maxint, 0) - - self.checkequal(-1, '', 'find', 'xx') - self.checkequal(-1, '', 'find', 'xx', 1, 1) - self.checkequal(-1, '', 'find', 'xx', sys.maxint, 0) - - # issue 7458 - self.checkequal(-1, 'ab', 'find', 'xxx', sys.maxsize + 1, 0) - - # For a variety of combinations, - # verify that str.find() matches __contains__ - # and that the found substring is really at that location - charset = ['', 'a', 'b', 'c'] - digits = 5 - base = len(charset) - teststrings = set() - for i in xrange(base ** digits): - entry = [] - for j in xrange(digits): - i, m = divmod(i, base) - entry.append(charset[m]) - teststrings.add(''.join(entry)) - teststrings = list(teststrings) - for i in teststrings: - i = self.fixtype(i) - for j in teststrings: - loc = i.find(j) - r1 = (loc != -1) - r2 = j in i - self.assertEqual(r1, r2) - if loc != -1: - self.assertEqual(i[loc:loc+len(j)], j) - - def test_rfind(self): - self.checkequal(9, 'abcdefghiabc', 'rfind', 'abc') - self.checkequal(12, 'abcdefghiabc', 'rfind', '') - self.checkequal(0, 'abcdefghiabc', 'rfind', 'abcd') - self.checkequal(-1, 'abcdefghiabc', 'rfind', 'abcz') - - self.checkequal(3, 'abc', 'rfind', '', 0) - self.checkequal(3, 'abc', 'rfind', '', 3) - self.checkequal(-1, 'abc', 
'rfind', '', 4) - - # to check the ability to pass None as defaults - self.checkequal(12, 'rrarrrrrrrrra', 'rfind', 'a') - self.checkequal(12, 'rrarrrrrrrrra', 'rfind', 'a', 4) - self.checkequal(-1, 'rrarrrrrrrrra', 'rfind', 'a', 4, 6) - self.checkequal(12, 'rrarrrrrrrrra', 'rfind', 'a', 4, None) - self.checkequal( 2, 'rrarrrrrrrrra', 'rfind', 'a', None, 6) - - self.checkraises(TypeError, 'hello', 'rfind') - self.checkraises(TypeError, 'hello', 'rfind', 42) - - # For a variety of combinations, - # verify that str.rfind() matches __contains__ - # and that the found substring is really at that location - charset = ['', 'a', 'b', 'c'] - digits = 5 - base = len(charset) - teststrings = set() - for i in xrange(base ** digits): - entry = [] - for j in xrange(digits): - i, m = divmod(i, base) - entry.append(charset[m]) - teststrings.add(''.join(entry)) - teststrings = list(teststrings) - for i in teststrings: - i = self.fixtype(i) - for j in teststrings: - loc = i.rfind(j) - r1 = (loc != -1) - r2 = j in i - self.assertEqual(r1, r2) - if loc != -1: - self.assertEqual(i[loc:loc+len(j)], self.fixtype(j)) - - # issue 7458 - self.checkequal(-1, 'ab', 'rfind', 'xxx', sys.maxsize + 1, 0) - - def test_index(self): - self.checkequal(0, 'abcdefghiabc', 'index', '') - self.checkequal(3, 'abcdefghiabc', 'index', 'def') - self.checkequal(0, 'abcdefghiabc', 'index', 'abc') - self.checkequal(9, 'abcdefghiabc', 'index', 'abc', 1) - - self.checkraises(ValueError, 'abcdefghiabc', 'index', 'hib') - self.checkraises(ValueError, 'abcdefghiab', 'index', 'abc', 1) - self.checkraises(ValueError, 'abcdefghi', 'index', 'ghi', 8) - self.checkraises(ValueError, 'abcdefghi', 'index', 'ghi', -1) - - # to check the ability to pass None as defaults - self.checkequal( 2, 'rrarrrrrrrrra', 'index', 'a') - self.checkequal(12, 'rrarrrrrrrrra', 'index', 'a', 4) - self.checkraises(ValueError, 'rrarrrrrrrrra', 'index', 'a', 4, 6) - self.checkequal(12, 'rrarrrrrrrrra', 'index', 'a', 4, None) - self.checkequal( 2, 'rrarrrrrrrrra', 'index', 'a', None, 6) - - self.checkraises(TypeError, 'hello', 'index') - self.checkraises(TypeError, 'hello', 'index', 42) - - def test_rindex(self): - self.checkequal(12, 'abcdefghiabc', 'rindex', '') - self.checkequal(3, 'abcdefghiabc', 'rindex', 'def') - self.checkequal(9, 'abcdefghiabc', 'rindex', 'abc') - self.checkequal(0, 'abcdefghiabc', 'rindex', 'abc', 0, -1) - - self.checkraises(ValueError, 'abcdefghiabc', 'rindex', 'hib') - self.checkraises(ValueError, 'defghiabc', 'rindex', 'def', 1) - self.checkraises(ValueError, 'defghiabc', 'rindex', 'abc', 0, -1) - self.checkraises(ValueError, 'abcdefghi', 'rindex', 'ghi', 0, 8) - self.checkraises(ValueError, 'abcdefghi', 'rindex', 'ghi', 0, -1) - - # to check the ability to pass None as defaults - self.checkequal(12, 'rrarrrrrrrrra', 'rindex', 'a') - self.checkequal(12, 'rrarrrrrrrrra', 'rindex', 'a', 4) - self.checkraises(ValueError, 'rrarrrrrrrrra', 'rindex', 'a', 4, 6) - self.checkequal(12, 'rrarrrrrrrrra', 'rindex', 'a', 4, None) - self.checkequal( 2, 'rrarrrrrrrrra', 'rindex', 'a', None, 6) - - self.checkraises(TypeError, 'hello', 'rindex') - self.checkraises(TypeError, 'hello', 'rindex', 42) - - def test_lower(self): - self.checkequal('hello', 'HeLLo', 'lower') - self.checkequal('hello', 'hello', 'lower') - self.checkraises(TypeError, 'hello', 'lower', 42) - - def test_upper(self): - self.checkequal('HELLO', 'HeLLo', 'upper') - self.checkequal('HELLO', 'HELLO', 'upper') - self.checkraises(TypeError, 'hello', 'upper', 42) - - def test_expandtabs(self): - 
self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs') - self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs', 8) - self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs', 4) - self.checkequal('abc\r\nab def\ng hi', 'abc\r\nab\tdef\ng\thi', 'expandtabs', 4) - self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs') - self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs', 8) - self.checkequal('abc\r\nab\r\ndef\ng\r\nhi', 'abc\r\nab\r\ndef\ng\r\nhi', 'expandtabs', 4) - self.checkequal(' a\n b', ' \ta\n\tb', 'expandtabs', 1) - - self.checkraises(TypeError, 'hello', 'expandtabs', 42, 42) - # This test is only valid when sizeof(int) == sizeof(void*) == 4. - if sys.maxint < (1 << 32) and struct.calcsize('P') == 4: - self.checkraises(OverflowError, - '\ta\n\tb', 'expandtabs', sys.maxint) - - def test_split(self): - self.checkequal(['this', 'is', 'the', 'split', 'function'], - 'this is the split function', 'split') - - # by whitespace - self.checkequal(['a', 'b', 'c', 'd'], 'a b c d ', 'split') - self.checkequal(['a', 'b c d'], 'a b c d', 'split', None, 1) - self.checkequal(['a', 'b', 'c d'], 'a b c d', 'split', None, 2) - self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'split', None, 3) - self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'split', None, 4) - self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'split', None, - sys.maxint-1) - self.checkequal(['a b c d'], 'a b c d', 'split', None, 0) - self.checkequal(['a b c d'], ' a b c d', 'split', None, 0) - self.checkequal(['a', 'b', 'c d'], 'a b c d', 'split', None, 2) - - self.checkequal([], ' ', 'split') - self.checkequal(['a'], ' a ', 'split') - self.checkequal(['a', 'b'], ' a b ', 'split') - self.checkequal(['a', 'b '], ' a b ', 'split', None, 1) - self.checkequal(['a b c '], ' a b c ', 'split', None, 0) - self.checkequal(['a', 'b c '], ' a b c ', 'split', None, 1) - self.checkequal(['a', 'b', 'c '], ' a b c ', 'split', None, 2) - self.checkequal(['a', 'b', 'c'], ' a b c ', 'split', None, 3) - self.checkequal(['a', 'b'], '\n\ta \t\r b \v ', 'split') - aaa = ' a '*20 - self.checkequal(['a']*20, aaa, 'split') - self.checkequal(['a'] + [aaa[4:]], aaa, 'split', None, 1) - self.checkequal(['a']*19 + ['a '], aaa, 'split', None, 19) - - for b in ('arf\tbarf', 'arf\nbarf', 'arf\rbarf', - 'arf\fbarf', 'arf\vbarf'): - self.checkequal(['arf', 'barf'], b, 'split') - self.checkequal(['arf', 'barf'], b, 'split', None) - self.checkequal(['arf', 'barf'], b, 'split', None, 2) - - # by a char - self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'split', '|') - self.checkequal(['a|b|c|d'], 'a|b|c|d', 'split', '|', 0) - self.checkequal(['a', 'b|c|d'], 'a|b|c|d', 'split', '|', 1) - self.checkequal(['a', 'b', 'c|d'], 'a|b|c|d', 'split', '|', 2) - self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'split', '|', 3) - self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'split', '|', 4) - self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'split', '|', - sys.maxint-2) - self.checkequal(['a|b|c|d'], 'a|b|c|d', 'split', '|', 0) - self.checkequal(['a', '', 'b||c||d'], 'a||b||c||d', 'split', '|', 2) - self.checkequal(['abcd'], 'abcd', 'split', '|') - self.checkequal([''], '', 'split', '|') - self.checkequal(['endcase ', ''], 'endcase |', 'split', '|') - self.checkequal(['', ' startcase'], '| startcase', 'split', '|') - self.checkequal(['', 'bothcase', ''], '|bothcase|', 'split', '|') - self.checkequal(['a', '', 'b\x00c\x00d'], 'a\x00\x00b\x00c\x00d', 'split', '\x00', 2) - - 
self.checkequal(['a']*20, ('a|'*20)[:-1], 'split', '|') - self.checkequal(['a']*15 +['a|a|a|a|a'], - ('a|'*20)[:-1], 'split', '|', 15) - - # by string - self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'split', '//') - self.checkequal(['a', 'b//c//d'], 'a//b//c//d', 'split', '//', 1) - self.checkequal(['a', 'b', 'c//d'], 'a//b//c//d', 'split', '//', 2) - self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'split', '//', 3) - self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'split', '//', 4) - self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'split', '//', - sys.maxint-10) - self.checkequal(['a//b//c//d'], 'a//b//c//d', 'split', '//', 0) - self.checkequal(['a', '', 'b////c////d'], 'a////b////c////d', 'split', '//', 2) - self.checkequal(['endcase ', ''], 'endcase test', 'split', 'test') - self.checkequal(['', ' begincase'], 'test begincase', 'split', 'test') - self.checkequal(['', ' bothcase ', ''], 'test bothcase test', - 'split', 'test') - self.checkequal(['a', 'bc'], 'abbbc', 'split', 'bb') - self.checkequal(['', ''], 'aaa', 'split', 'aaa') - self.checkequal(['aaa'], 'aaa', 'split', 'aaa', 0) - self.checkequal(['ab', 'ab'], 'abbaab', 'split', 'ba') - self.checkequal(['aaaa'], 'aaaa', 'split', 'aab') - self.checkequal([''], '', 'split', 'aaa') - self.checkequal(['aa'], 'aa', 'split', 'aaa') - self.checkequal(['A', 'bobb'], 'Abbobbbobb', 'split', 'bbobb') - self.checkequal(['A', 'B', ''], 'AbbobbBbbobb', 'split', 'bbobb') - - self.checkequal(['a']*20, ('aBLAH'*20)[:-4], 'split', 'BLAH') - self.checkequal(['a']*20, ('aBLAH'*20)[:-4], 'split', 'BLAH', 19) - self.checkequal(['a']*18 + ['aBLAHa'], ('aBLAH'*20)[:-4], - 'split', 'BLAH', 18) - - # mixed use of str and unicode - if self.type2test is not bytearray: - result = [u'a', u'b', u'c d'] - self.checkequal(result, 'a b c d', 'split', u' ', 2) - - # argument type - self.checkraises(TypeError, 'hello', 'split', 42, 42, 42) - - # null case - self.checkraises(ValueError, 'hello', 'split', '') - self.checkraises(ValueError, 'hello', 'split', '', 0) - - def test_rsplit(self): - self.checkequal(['this', 'is', 'the', 'rsplit', 'function'], - 'this is the rsplit function', 'rsplit') - - # by whitespace - self.checkequal(['a', 'b', 'c', 'd'], 'a b c d ', 'rsplit') - self.checkequal(['a b c', 'd'], 'a b c d', 'rsplit', None, 1) - self.checkequal(['a b', 'c', 'd'], 'a b c d', 'rsplit', None, 2) - self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'rsplit', None, 3) - self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'rsplit', None, 4) - self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'rsplit', None, - sys.maxint-20) - self.checkequal(['a b c d'], 'a b c d', 'rsplit', None, 0) - self.checkequal(['a b c d'], 'a b c d ', 'rsplit', None, 0) - self.checkequal(['a b', 'c', 'd'], 'a b c d', 'rsplit', None, 2) - - self.checkequal([], ' ', 'rsplit') - self.checkequal(['a'], ' a ', 'rsplit') - self.checkequal(['a', 'b'], ' a b ', 'rsplit') - self.checkequal([' a', 'b'], ' a b ', 'rsplit', None, 1) - self.checkequal([' a b c'], ' a b c ', 'rsplit', - None, 0) - self.checkequal([' a b','c'], ' a b c ', 'rsplit', - None, 1) - self.checkequal([' a', 'b', 'c'], ' a b c ', 'rsplit', - None, 2) - self.checkequal(['a', 'b', 'c'], ' a b c ', 'rsplit', - None, 3) - self.checkequal(['a', 'b'], '\n\ta \t\r b \v ', 'rsplit', None, 88) - aaa = ' a '*20 - self.checkequal(['a']*20, aaa, 'rsplit') - self.checkequal([aaa[:-4]] + ['a'], aaa, 'rsplit', None, 1) - self.checkequal([' a a'] + ['a']*18, aaa, 'rsplit', None, 18) - - for b in ('arf\tbarf', 'arf\nbarf', 'arf\rbarf', 
- 'arf\fbarf', 'arf\vbarf'): - self.checkequal(['arf', 'barf'], b, 'rsplit') - self.checkequal(['arf', 'barf'], b, 'rsplit', None) - self.checkequal(['arf', 'barf'], b, 'rsplit', None, 2) - - # by a char - self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|') - self.checkequal(['a|b|c', 'd'], 'a|b|c|d', 'rsplit', '|', 1) - self.checkequal(['a|b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|', 2) - self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|', 3) - self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|', 4) - self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|', - sys.maxint-100) - self.checkequal(['a|b|c|d'], 'a|b|c|d', 'rsplit', '|', 0) - self.checkequal(['a||b||c', '', 'd'], 'a||b||c||d', 'rsplit', '|', 2) - self.checkequal(['abcd'], 'abcd', 'rsplit', '|') - self.checkequal([''], '', 'rsplit', '|') - self.checkequal(['', ' begincase'], '| begincase', 'rsplit', '|') - self.checkequal(['endcase ', ''], 'endcase |', 'rsplit', '|') - self.checkequal(['', 'bothcase', ''], '|bothcase|', 'rsplit', '|') - - self.checkequal(['a\x00\x00b', 'c', 'd'], 'a\x00\x00b\x00c\x00d', 'rsplit', '\x00', 2) - - self.checkequal(['a']*20, ('a|'*20)[:-1], 'rsplit', '|') - self.checkequal(['a|a|a|a|a']+['a']*15, - ('a|'*20)[:-1], 'rsplit', '|', 15) - - # by string - self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'rsplit', '//') - self.checkequal(['a//b//c', 'd'], 'a//b//c//d', 'rsplit', '//', 1) - self.checkequal(['a//b', 'c', 'd'], 'a//b//c//d', 'rsplit', '//', 2) - self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'rsplit', '//', 3) - self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'rsplit', '//', 4) - self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'rsplit', '//', - sys.maxint-5) - self.checkequal(['a//b//c//d'], 'a//b//c//d', 'rsplit', '//', 0) - self.checkequal(['a////b////c', '', 'd'], 'a////b////c////d', 'rsplit', '//', 2) - self.checkequal(['', ' begincase'], 'test begincase', 'rsplit', 'test') - self.checkequal(['endcase ', ''], 'endcase test', 'rsplit', 'test') - self.checkequal(['', ' bothcase ', ''], 'test bothcase test', - 'rsplit', 'test') - self.checkequal(['ab', 'c'], 'abbbc', 'rsplit', 'bb') - self.checkequal(['', ''], 'aaa', 'rsplit', 'aaa') - self.checkequal(['aaa'], 'aaa', 'rsplit', 'aaa', 0) - self.checkequal(['ab', 'ab'], 'abbaab', 'rsplit', 'ba') - self.checkequal(['aaaa'], 'aaaa', 'rsplit', 'aab') - self.checkequal([''], '', 'rsplit', 'aaa') - self.checkequal(['aa'], 'aa', 'rsplit', 'aaa') - self.checkequal(['bbob', 'A'], 'bbobbbobbA', 'rsplit', 'bbobb') - self.checkequal(['', 'B', 'A'], 'bbobbBbbobbA', 'rsplit', 'bbobb') - - self.checkequal(['a']*20, ('aBLAH'*20)[:-4], 'rsplit', 'BLAH') - self.checkequal(['a']*20, ('aBLAH'*20)[:-4], 'rsplit', 'BLAH', 19) - self.checkequal(['aBLAHa'] + ['a']*18, ('aBLAH'*20)[:-4], - 'rsplit', 'BLAH', 18) - - # mixed use of str and unicode - if self.type2test is not bytearray: - result = [u'a b', u'c', u'd'] - self.checkequal(result, 'a b c d', 'rsplit', u' ', 2) - - # argument type - self.checkraises(TypeError, 'hello', 'rsplit', 42, 42, 42) - - # null case - self.checkraises(ValueError, 'hello', 'rsplit', '') - self.checkraises(ValueError, 'hello', 'rsplit', '', 0) +# def test_capitalize(self): +# self.checkequal(' hello ', ' hello ', 'capitalize') +# self.checkequal('Hello ', 'Hello ','capitalize') +# self.checkequal('Hello ', 'hello ','capitalize') +# self.checkequal('Aaaa', 'aaaa', 'capitalize') +# self.checkequal('Aaaa', 'AaAa', 'capitalize') +# +# self.checkraises(TypeError, 'hello', 'capitalize', 42) 
+ +# def test_count(self): +# self.checkequal(3, 'aaa', 'count', 'a') +# self.checkequal(0, 'aaa', 'count', 'b') +# self.checkequal(3, 'aaa', 'count', 'a') +# self.checkequal(0, 'aaa', 'count', 'b') +# self.checkequal(3, 'aaa', 'count', 'a') +# self.checkequal(0, 'aaa', 'count', 'b') +# self.checkequal(0, 'aaa', 'count', 'b') +# self.checkequal(2, 'aaa', 'count', 'a', 1) +# self.checkequal(0, 'aaa', 'count', 'a', 10) +# self.checkequal(1, 'aaa', 'count', 'a', -1) +# self.checkequal(3, 'aaa', 'count', 'a', -10) +# self.checkequal(1, 'aaa', 'count', 'a', 0, 1) +# self.checkequal(3, 'aaa', 'count', 'a', 0, 10) +# self.checkequal(2, 'aaa', 'count', 'a', 0, -1) +# self.checkequal(0, 'aaa', 'count', 'a', 0, -10) +# self.checkequal(3, 'aaa', 'count', '', 1) +# self.checkequal(1, 'aaa', 'count', '', 3) +# self.checkequal(0, 'aaa', 'count', '', 10) +# self.checkequal(2, 'aaa', 'count', '', -1) +# self.checkequal(4, 'aaa', 'count', '', -10) +# +# self.checkequal(1, '', 'count', '') +# self.checkequal(0, '', 'count', '', 1, 1) +# self.checkequal(0, '', 'count', '', sys.maxint, 0) +# +# self.checkequal(0, '', 'count', 'xx') +# self.checkequal(0, '', 'count', 'xx', 1, 1) +# self.checkequal(0, '', 'count', 'xx', sys.maxint, 0) +# +# self.checkraises(TypeError, 'hello', 'count') +# self.checkraises(TypeError, 'hello', 'count', 42) +# +# # For a variety of combinations, +# # verify that str.count() matches an equivalent function +# # replacing all occurrences and then differencing the string lengths +# charset = ['', 'a', 'b'] +# digits = 7 +# base = len(charset) +# teststrings = set() +# for i in xrange(base ** digits): +# entry = [] +# for j in xrange(digits): +# i, m = divmod(i, base) +# entry.append(charset[m]) +# teststrings.add(''.join(entry)) +# teststrings = list(teststrings) +# for i in teststrings: +# i = self.fixtype(i) +# n = len(i) +# for j in teststrings: +# r1 = i.count(j) +# if j: +# r2, rem = divmod(n - len(i.replace(j, '')), len(j)) +# else: +# r2, rem = len(i)+1, 0 +# if rem or r1 != r2: +# self.assertEqual(rem, 0, '%s != 0 for %s' % (rem, i)) +# self.assertEqual(r1, r2, '%s != %s for %s' % (r1, r2, i)) + +# def test_find(self): +# self.checkequal(0, 'abcdefghiabc', 'find', 'abc') +# self.checkequal(9, 'abcdefghiabc', 'find', 'abc', 1) +# self.checkequal(-1, 'abcdefghiabc', 'find', 'def', 4) +# +# self.checkequal(0, 'abc', 'find', '', 0) +# self.checkequal(3, 'abc', 'find', '', 3) +# self.checkequal(-1, 'abc', 'find', '', 4) +# +# # to check the ability to pass None as defaults +# self.checkequal( 2, 'rrarrrrrrrrra', 'find', 'a') +# self.checkequal(12, 'rrarrrrrrrrra', 'find', 'a', 4) +# self.checkequal(-1, 'rrarrrrrrrrra', 'find', 'a', 4, 6) +# self.checkequal(12, 'rrarrrrrrrrra', 'find', 'a', 4, None) +# self.checkequal( 2, 'rrarrrrrrrrra', 'find', 'a', None, 6) +# +# self.checkraises(TypeError, 'hello', 'find') +# self.checkraises(TypeError, 'hello', 'find', 42) +# +# self.checkequal(0, '', 'find', '') +# self.checkequal(-1, '', 'find', '', 1, 1) +# self.checkequal(-1, '', 'find', '', sys.maxint, 0) +# +# self.checkequal(-1, '', 'find', 'xx') +# self.checkequal(-1, '', 'find', 'xx', 1, 1) +# self.checkequal(-1, '', 'find', 'xx', sys.maxint, 0) +# +# # issue 7458 +# self.checkequal(-1, 'ab', 'find', 'xxx', sys.maxsize + 1, 0) +# +# # For a variety of combinations, +# # verify that str.find() matches __contains__ +# # and that the found substring is really at that location +# charset = ['', 'a', 'b', 'c'] +# digits = 5 +# base = len(charset) +# teststrings = set() +# for i in 
xrange(base ** digits): +# entry = [] +# for j in xrange(digits): +# i, m = divmod(i, base) +# entry.append(charset[m]) +# teststrings.add(''.join(entry)) +# teststrings = list(teststrings) +# for i in teststrings: +# i = self.fixtype(i) +# for j in teststrings: +# loc = i.find(j) +# r1 = (loc != -1) +# r2 = j in i +# self.assertEqual(r1, r2) +# if loc != -1: +# self.assertEqual(i[loc:loc+len(j)], j) + +# def test_rfind(self): +# self.checkequal(9, 'abcdefghiabc', 'rfind', 'abc') +# self.checkequal(12, 'abcdefghiabc', 'rfind', '') +# self.checkequal(0, 'abcdefghiabc', 'rfind', 'abcd') +# self.checkequal(-1, 'abcdefghiabc', 'rfind', 'abcz') +# +# self.checkequal(3, 'abc', 'rfind', '', 0) +# self.checkequal(3, 'abc', 'rfind', '', 3) +# self.checkequal(-1, 'abc', 'rfind', '', 4) +# +# # to check the ability to pass None as defaults +# self.checkequal(12, 'rrarrrrrrrrra', 'rfind', 'a') +# self.checkequal(12, 'rrarrrrrrrrra', 'rfind', 'a', 4) +# self.checkequal(-1, 'rrarrrrrrrrra', 'rfind', 'a', 4, 6) +# self.checkequal(12, 'rrarrrrrrrrra', 'rfind', 'a', 4, None) +# self.checkequal( 2, 'rrarrrrrrrrra', 'rfind', 'a', None, 6) +# +# self.checkraises(TypeError, 'hello', 'rfind') +# self.checkraises(TypeError, 'hello', 'rfind', 42) +# +# # For a variety of combinations, +# # verify that str.rfind() matches __contains__ +# # and that the found substring is really at that location +# charset = ['', 'a', 'b', 'c'] +# digits = 5 +# base = len(charset) +# teststrings = set() +# for i in xrange(base ** digits): +# entry = [] +# for j in xrange(digits): +# i, m = divmod(i, base) +# entry.append(charset[m]) +# teststrings.add(''.join(entry)) +# teststrings = list(teststrings) +# for i in teststrings: +# i = self.fixtype(i) +# for j in teststrings: +# loc = i.rfind(j) +# r1 = (loc != -1) +# r2 = j in i +# self.assertEqual(r1, r2) +# if loc != -1: +# self.assertEqual(i[loc:loc+len(j)], self.fixtype(j)) +# +# # issue 7458 +# self.checkequal(-1, 'ab', 'rfind', 'xxx', sys.maxsize + 1, 0) + +# def test_index(self): +# self.checkequal(0, 'abcdefghiabc', 'index', '') +# self.checkequal(3, 'abcdefghiabc', 'index', 'def') +# self.checkequal(0, 'abcdefghiabc', 'index', 'abc') +# self.checkequal(9, 'abcdefghiabc', 'index', 'abc', 1) +# +# self.checkraises(ValueError, 'abcdefghiabc', 'index', 'hib') +# self.checkraises(ValueError, 'abcdefghiab', 'index', 'abc', 1) +# self.checkraises(ValueError, 'abcdefghi', 'index', 'ghi', 8) +# self.checkraises(ValueError, 'abcdefghi', 'index', 'ghi', -1) +# +# # to check the ability to pass None as defaults +# self.checkequal( 2, 'rrarrrrrrrrra', 'index', 'a') +# self.checkequal(12, 'rrarrrrrrrrra', 'index', 'a', 4) +# self.checkraises(ValueError, 'rrarrrrrrrrra', 'index', 'a', 4, 6) +# self.checkequal(12, 'rrarrrrrrrrra', 'index', 'a', 4, None) +# self.checkequal( 2, 'rrarrrrrrrrra', 'index', 'a', None, 6) +# +# self.checkraises(TypeError, 'hello', 'index') +# self.checkraises(TypeError, 'hello', 'index', 42) + +# def test_rindex(self): +# self.checkequal(12, 'abcdefghiabc', 'rindex', '') +# self.checkequal(3, 'abcdefghiabc', 'rindex', 'def') +# self.checkequal(9, 'abcdefghiabc', 'rindex', 'abc') +# self.checkequal(0, 'abcdefghiabc', 'rindex', 'abc', 0, -1) +# +# self.checkraises(ValueError, 'abcdefghiabc', 'rindex', 'hib') +# self.checkraises(ValueError, 'defghiabc', 'rindex', 'def', 1) +# self.checkraises(ValueError, 'defghiabc', 'rindex', 'abc', 0, -1) +# self.checkraises(ValueError, 'abcdefghi', 'rindex', 'ghi', 0, 8) +# self.checkraises(ValueError, 'abcdefghi', 'rindex', 
'ghi', 0, -1) +# +# # to check the ability to pass None as defaults +# self.checkequal(12, 'rrarrrrrrrrra', 'rindex', 'a') +# self.checkequal(12, 'rrarrrrrrrrra', 'rindex', 'a', 4) +# self.checkraises(ValueError, 'rrarrrrrrrrra', 'rindex', 'a', 4, 6) +# self.checkequal(12, 'rrarrrrrrrrra', 'rindex', 'a', 4, None) +# self.checkequal( 2, 'rrarrrrrrrrra', 'rindex', 'a', None, 6) +# +# self.checkraises(TypeError, 'hello', 'rindex') +# self.checkraises(TypeError, 'hello', 'rindex', 42) + +# def test_lower(self): +# self.checkequal('hello', 'HeLLo', 'lower') +# self.checkequal('hello', 'hello', 'lower') +# self.checkraises(TypeError, 'hello', 'lower', 42) + +# def test_upper(self): +# self.checkequal('HELLO', 'HeLLo', 'upper') +# self.checkequal('HELLO', 'HELLO', 'upper') +# self.checkraises(TypeError, 'hello', 'upper', 42) + +# def test_expandtabs(self): +# self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs') +# self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs', 8) +# self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs', 4) +# self.checkequal('abc\r\nab def\ng hi', 'abc\r\nab\tdef\ng\thi', 'expandtabs', 4) +# self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs') +# self.checkequal('abc\rab def\ng hi', 'abc\rab\tdef\ng\thi', 'expandtabs', 8) +# self.checkequal('abc\r\nab\r\ndef\ng\r\nhi', 'abc\r\nab\r\ndef\ng\r\nhi', 'expandtabs', 4) +# self.checkequal(' a\n b', ' \ta\n\tb', 'expandtabs', 1) +# +# self.checkraises(TypeError, 'hello', 'expandtabs', 42, 42) +# # This test is only valid when sizeof(int) == sizeof(void*) == 4. +# if sys.maxint < (1 << 32) and struct.calcsize('P') == 4: +# self.checkraises(OverflowError, +# '\ta\n\tb', 'expandtabs', sys.maxint) + +# def test_split(self): +# self.checkequal(['this', 'is', 'the', 'split', 'function'], +# 'this is the split function', 'split') +# +# # by whitespace +# self.checkequal(['a', 'b', 'c', 'd'], 'a b c d ', 'split') +# self.checkequal(['a', 'b c d'], 'a b c d', 'split', None, 1) +# self.checkequal(['a', 'b', 'c d'], 'a b c d', 'split', None, 2) +# self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'split', None, 3) +# self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'split', None, 4) +# self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'split', None, +# sys.maxint-1) +# self.checkequal(['a b c d'], 'a b c d', 'split', None, 0) +# self.checkequal(['a b c d'], ' a b c d', 'split', None, 0) +# self.checkequal(['a', 'b', 'c d'], 'a b c d', 'split', None, 2) +# +# self.checkequal([], ' ', 'split') +# self.checkequal(['a'], ' a ', 'split') +# self.checkequal(['a', 'b'], ' a b ', 'split') +# self.checkequal(['a', 'b '], ' a b ', 'split', None, 1) +# self.checkequal(['a b c '], ' a b c ', 'split', None, 0) +# self.checkequal(['a', 'b c '], ' a b c ', 'split', None, 1) +# self.checkequal(['a', 'b', 'c '], ' a b c ', 'split', None, 2) +# self.checkequal(['a', 'b', 'c'], ' a b c ', 'split', None, 3) +# self.checkequal(['a', 'b'], '\n\ta \t\r b \v ', 'split') +# aaa = ' a '*20 +# self.checkequal(['a']*20, aaa, 'split') +# self.checkequal(['a'] + [aaa[4:]], aaa, 'split', None, 1) +# self.checkequal(['a']*19 + ['a '], aaa, 'split', None, 19) +# +# for b in ('arf\tbarf', 'arf\nbarf', 'arf\rbarf', +# 'arf\fbarf', 'arf\vbarf'): +# self.checkequal(['arf', 'barf'], b, 'split') +# self.checkequal(['arf', 'barf'], b, 'split', None) +# self.checkequal(['arf', 'barf'], b, 'split', None, 2) +# +# # by a char +# self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'split', '|') +# 
self.checkequal(['a|b|c|d'], 'a|b|c|d', 'split', '|', 0) +# self.checkequal(['a', 'b|c|d'], 'a|b|c|d', 'split', '|', 1) +# self.checkequal(['a', 'b', 'c|d'], 'a|b|c|d', 'split', '|', 2) +# self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'split', '|', 3) +# self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'split', '|', 4) +# self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'split', '|', +# sys.maxint-2) +# self.checkequal(['a|b|c|d'], 'a|b|c|d', 'split', '|', 0) +# self.checkequal(['a', '', 'b||c||d'], 'a||b||c||d', 'split', '|', 2) +# self.checkequal(['abcd'], 'abcd', 'split', '|') +# self.checkequal([''], '', 'split', '|') +# self.checkequal(['endcase ', ''], 'endcase |', 'split', '|') +# self.checkequal(['', ' startcase'], '| startcase', 'split', '|') +# self.checkequal(['', 'bothcase', ''], '|bothcase|', 'split', '|') +# self.checkequal(['a', '', 'b\x00c\x00d'], 'a\x00\x00b\x00c\x00d', 'split', '\x00', 2) +# +# self.checkequal(['a']*20, ('a|'*20)[:-1], 'split', '|') +# self.checkequal(['a']*15 +['a|a|a|a|a'], +# ('a|'*20)[:-1], 'split', '|', 15) +# +# # by string +# self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'split', '//') +# self.checkequal(['a', 'b//c//d'], 'a//b//c//d', 'split', '//', 1) +# self.checkequal(['a', 'b', 'c//d'], 'a//b//c//d', 'split', '//', 2) +# self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'split', '//', 3) +# self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'split', '//', 4) +# self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'split', '//', +# sys.maxint-10) +# self.checkequal(['a//b//c//d'], 'a//b//c//d', 'split', '//', 0) +# self.checkequal(['a', '', 'b////c////d'], 'a////b////c////d', 'split', '//', 2) +# self.checkequal(['endcase ', ''], 'endcase test', 'split', 'test') +# self.checkequal(['', ' begincase'], 'test begincase', 'split', 'test') +# self.checkequal(['', ' bothcase ', ''], 'test bothcase test', +# 'split', 'test') +# self.checkequal(['a', 'bc'], 'abbbc', 'split', 'bb') +# self.checkequal(['', ''], 'aaa', 'split', 'aaa') +# self.checkequal(['aaa'], 'aaa', 'split', 'aaa', 0) +# self.checkequal(['ab', 'ab'], 'abbaab', 'split', 'ba') +# self.checkequal(['aaaa'], 'aaaa', 'split', 'aab') +# self.checkequal([''], '', 'split', 'aaa') +# self.checkequal(['aa'], 'aa', 'split', 'aaa') +# self.checkequal(['A', 'bobb'], 'Abbobbbobb', 'split', 'bbobb') +# self.checkequal(['A', 'B', ''], 'AbbobbBbbobb', 'split', 'bbobb') +# +# self.checkequal(['a']*20, ('aBLAH'*20)[:-4], 'split', 'BLAH') +# self.checkequal(['a']*20, ('aBLAH'*20)[:-4], 'split', 'BLAH', 19) +# self.checkequal(['a']*18 + ['aBLAHa'], ('aBLAH'*20)[:-4], +# 'split', 'BLAH', 18) +# +# # mixed use of str and unicode +# if self.type2test is not bytearray: +# result = [u'a', u'b', u'c d'] +# self.checkequal(result, 'a b c d', 'split', u' ', 2) +# +# # argument type +# self.checkraises(TypeError, 'hello', 'split', 42, 42, 42) +# +# # null case +# self.checkraises(ValueError, 'hello', 'split', '') +# self.checkraises(ValueError, 'hello', 'split', '', 0) + +# def test_rsplit(self): +# self.checkequal(['this', 'is', 'the', 'rsplit', 'function'], +# 'this is the rsplit function', 'rsplit') +# +# # by whitespace +# self.checkequal(['a', 'b', 'c', 'd'], 'a b c d ', 'rsplit') +# self.checkequal(['a b c', 'd'], 'a b c d', 'rsplit', None, 1) +# self.checkequal(['a b', 'c', 'd'], 'a b c d', 'rsplit', None, 2) +# self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'rsplit', None, 3) +# self.checkequal(['a', 'b', 'c', 'd'], 'a b c d', 'rsplit', None, 4) +# self.checkequal(['a', 'b', 'c', 'd'], 'a b c 
d', 'rsplit', None, +# sys.maxint-20) +# self.checkequal(['a b c d'], 'a b c d', 'rsplit', None, 0) +# self.checkequal(['a b c d'], 'a b c d ', 'rsplit', None, 0) +# self.checkequal(['a b', 'c', 'd'], 'a b c d', 'rsplit', None, 2) +# +# self.checkequal([], ' ', 'rsplit') +# self.checkequal(['a'], ' a ', 'rsplit') +# self.checkequal(['a', 'b'], ' a b ', 'rsplit') +# self.checkequal([' a', 'b'], ' a b ', 'rsplit', None, 1) +# self.checkequal([' a b c'], ' a b c ', 'rsplit', +# None, 0) +# self.checkequal([' a b','c'], ' a b c ', 'rsplit', +# None, 1) +# self.checkequal([' a', 'b', 'c'], ' a b c ', 'rsplit', +# None, 2) +# self.checkequal(['a', 'b', 'c'], ' a b c ', 'rsplit', +# None, 3) +# self.checkequal(['a', 'b'], '\n\ta \t\r b \v ', 'rsplit', None, 88) +# aaa = ' a '*20 +# self.checkequal(['a']*20, aaa, 'rsplit') +# self.checkequal([aaa[:-4]] + ['a'], aaa, 'rsplit', None, 1) +# self.checkequal([' a a'] + ['a']*18, aaa, 'rsplit', None, 18) +# +# for b in ('arf\tbarf', 'arf\nbarf', 'arf\rbarf', +# 'arf\fbarf', 'arf\vbarf'): +# self.checkequal(['arf', 'barf'], b, 'rsplit') +# self.checkequal(['arf', 'barf'], b, 'rsplit', None) +# self.checkequal(['arf', 'barf'], b, 'rsplit', None, 2) +# +# # by a char +# self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|') +# self.checkequal(['a|b|c', 'd'], 'a|b|c|d', 'rsplit', '|', 1) +# self.checkequal(['a|b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|', 2) +# self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|', 3) +# self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|', 4) +# self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|', +# sys.maxint-100) +# self.checkequal(['a|b|c|d'], 'a|b|c|d', 'rsplit', '|', 0) +# self.checkequal(['a||b||c', '', 'd'], 'a||b||c||d', 'rsplit', '|', 2) +# self.checkequal(['abcd'], 'abcd', 'rsplit', '|') +# self.checkequal([''], '', 'rsplit', '|') +# self.checkequal(['', ' begincase'], '| begincase', 'rsplit', '|') +# self.checkequal(['endcase ', ''], 'endcase |', 'rsplit', '|') +# self.checkequal(['', 'bothcase', ''], '|bothcase|', 'rsplit', '|') +# +# self.checkequal(['a\x00\x00b', 'c', 'd'], 'a\x00\x00b\x00c\x00d', 'rsplit', '\x00', 2) +# +# self.checkequal(['a']*20, ('a|'*20)[:-1], 'rsplit', '|') +# self.checkequal(['a|a|a|a|a']+['a']*15, +# ('a|'*20)[:-1], 'rsplit', '|', 15) +# +# # by string +# self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'rsplit', '//') +# self.checkequal(['a//b//c', 'd'], 'a//b//c//d', 'rsplit', '//', 1) +# self.checkequal(['a//b', 'c', 'd'], 'a//b//c//d', 'rsplit', '//', 2) +# self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'rsplit', '//', 3) +# self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'rsplit', '//', 4) +# self.checkequal(['a', 'b', 'c', 'd'], 'a//b//c//d', 'rsplit', '//', +# sys.maxint-5) +# self.checkequal(['a//b//c//d'], 'a//b//c//d', 'rsplit', '//', 0) +# self.checkequal(['a////b////c', '', 'd'], 'a////b////c////d', 'rsplit', '//', 2) +# self.checkequal(['', ' begincase'], 'test begincase', 'rsplit', 'test') +# self.checkequal(['endcase ', ''], 'endcase test', 'rsplit', 'test') +# self.checkequal(['', ' bothcase ', ''], 'test bothcase test', +# 'rsplit', 'test') +# self.checkequal(['ab', 'c'], 'abbbc', 'rsplit', 'bb') +# self.checkequal(['', ''], 'aaa', 'rsplit', 'aaa') +# self.checkequal(['aaa'], 'aaa', 'rsplit', 'aaa', 0) +# self.checkequal(['ab', 'ab'], 'abbaab', 'rsplit', 'ba') +# self.checkequal(['aaaa'], 'aaaa', 'rsplit', 'aab') +# self.checkequal([''], '', 'rsplit', 'aaa') +# self.checkequal(['aa'], 'aa', 'rsplit', 'aaa') +# 
self.checkequal(['bbob', 'A'], 'bbobbbobbA', 'rsplit', 'bbobb') +# self.checkequal(['', 'B', 'A'], 'bbobbBbbobbA', 'rsplit', 'bbobb') +# +# self.checkequal(['a']*20, ('aBLAH'*20)[:-4], 'rsplit', 'BLAH') +# self.checkequal(['a']*20, ('aBLAH'*20)[:-4], 'rsplit', 'BLAH', 19) +# self.checkequal(['aBLAHa'] + ['a']*18, ('aBLAH'*20)[:-4], +# 'rsplit', 'BLAH', 18) +# +# # mixed use of str and unicode +# if self.type2test is not bytearray: +# result = [u'a b', u'c', u'd'] +# self.checkequal(result, 'a b c d', 'rsplit', u' ', 2) +# +# # argument type +# self.checkraises(TypeError, 'hello', 'rsplit', 42, 42, 42) +# +# # null case +# self.checkraises(ValueError, 'hello', 'rsplit', '') +# self.checkraises(ValueError, 'hello', 'rsplit', '', 0) def test_strip_whitespace(self): self.checkequal('hello', ' hello ', 'strip') @@ -552,255 +552,255 @@ def test_strip_whitespace(self): self.checkequal(' hello', ' hello ', 'rstrip', None) self.checkequal('hello', 'hello', 'strip', None) - def test_strip(self): - # strip/lstrip/rstrip with str arg - self.checkequal('hello', 'xyzzyhelloxyzzy', 'strip', 'xyz') - self.checkequal('helloxyzzy', 'xyzzyhelloxyzzy', 'lstrip', 'xyz') - self.checkequal('xyzzyhello', 'xyzzyhelloxyzzy', 'rstrip', 'xyz') - self.checkequal('hello', 'hello', 'strip', 'xyz') - self.checkequal('', 'mississippi', 'strip', 'mississippi') - - # only trims the start and end, does not strip internal characters - self.checkequal('mississipp', 'mississippi', 'strip', 'i') - - # strip/lstrip/rstrip with unicode arg - if self.type2test is not bytearray and test_support.have_unicode: - self.checkequal(unicode('hello', 'ascii'), 'xyzzyhelloxyzzy', - 'strip', unicode('xyz', 'ascii')) - self.checkequal(unicode('helloxyzzy', 'ascii'), 'xyzzyhelloxyzzy', - 'lstrip', unicode('xyz', 'ascii')) - self.checkequal(unicode('xyzzyhello', 'ascii'), 'xyzzyhelloxyzzy', - 'rstrip', unicode('xyz', 'ascii')) - # XXX - #self.checkequal(unicode('hello', 'ascii'), 'hello', - # 'strip', unicode('xyz', 'ascii')) - - self.checkraises(TypeError, 'hello', 'strip', 42, 42) - self.checkraises(TypeError, 'hello', 'lstrip', 42, 42) - self.checkraises(TypeError, 'hello', 'rstrip', 42, 42) - - def test_ljust(self): - self.checkequal('abc ', 'abc', 'ljust', 10) - self.checkequal('abc ', 'abc', 'ljust', 6) - self.checkequal('abc', 'abc', 'ljust', 3) - self.checkequal('abc', 'abc', 'ljust', 2) - if self.type2test is bytearray: - # Special case because bytearray argument is not accepted - self.assertEqual(b'abc*******', bytearray(b'abc').ljust(10, '*')) - else: - self.checkequal('abc*******', 'abc', 'ljust', 10, '*') - self.checkraises(TypeError, 'abc', 'ljust') - - def test_rjust(self): - self.checkequal(' abc', 'abc', 'rjust', 10) - self.checkequal(' abc', 'abc', 'rjust', 6) - self.checkequal('abc', 'abc', 'rjust', 3) - self.checkequal('abc', 'abc', 'rjust', 2) - if self.type2test is bytearray: - # Special case because bytearray argument is not accepted - self.assertEqual(b'*******abc', bytearray(b'abc').rjust(10, '*')) - else: - self.checkequal('*******abc', 'abc', 'rjust', 10, '*') - self.checkraises(TypeError, 'abc', 'rjust') - - def test_center(self): - self.checkequal(' abc ', 'abc', 'center', 10) - self.checkequal(' abc ', 'abc', 'center', 6) - self.checkequal('abc', 'abc', 'center', 3) - self.checkequal('abc', 'abc', 'center', 2) - if self.type2test is bytearray: - # Special case because bytearray argument is not accepted - result = bytearray(b'abc').center(10, '*') - self.assertEqual(b'***abc****', result) - else: - 
self.checkequal('***abc****', 'abc', 'center', 10, '*') - self.checkraises(TypeError, 'abc', 'center') - - def test_swapcase(self): - self.checkequal('hEllO CoMPuTErS', 'HeLLo cOmpUteRs', 'swapcase') - - self.checkraises(TypeError, 'hello', 'swapcase', 42) - - def test_replace(self): - EQ = self.checkequal - - # Operations on the empty string - EQ("", "", "replace", "", "") - EQ("A", "", "replace", "", "A") - EQ("", "", "replace", "A", "") - EQ("", "", "replace", "A", "A") - EQ("", "", "replace", "", "", 100) - EQ("", "", "replace", "", "", sys.maxint) - - # interleave (from=="", 'to' gets inserted everywhere) - EQ("A", "A", "replace", "", "") - EQ("*A*", "A", "replace", "", "*") - EQ("*1A*1", "A", "replace", "", "*1") - EQ("*-#A*-#", "A", "replace", "", "*-#") - EQ("*-A*-A*-", "AA", "replace", "", "*-") - EQ("*-A*-A*-", "AA", "replace", "", "*-", -1) - EQ("*-A*-A*-", "AA", "replace", "", "*-", sys.maxint) - EQ("*-A*-A*-", "AA", "replace", "", "*-", 4) - EQ("*-A*-A*-", "AA", "replace", "", "*-", 3) - EQ("*-A*-A", "AA", "replace", "", "*-", 2) - EQ("*-AA", "AA", "replace", "", "*-", 1) - EQ("AA", "AA", "replace", "", "*-", 0) - - # single character deletion (from=="A", to=="") - EQ("", "A", "replace", "A", "") - EQ("", "AAA", "replace", "A", "") - EQ("", "AAA", "replace", "A", "", -1) - EQ("", "AAA", "replace", "A", "", sys.maxint) - EQ("", "AAA", "replace", "A", "", 4) - EQ("", "AAA", "replace", "A", "", 3) - EQ("A", "AAA", "replace", "A", "", 2) - EQ("AA", "AAA", "replace", "A", "", 1) - EQ("AAA", "AAA", "replace", "A", "", 0) - EQ("", "AAAAAAAAAA", "replace", "A", "") - EQ("BCD", "ABACADA", "replace", "A", "") - EQ("BCD", "ABACADA", "replace", "A", "", -1) - EQ("BCD", "ABACADA", "replace", "A", "", sys.maxint) - EQ("BCD", "ABACADA", "replace", "A", "", 5) - EQ("BCD", "ABACADA", "replace", "A", "", 4) - EQ("BCDA", "ABACADA", "replace", "A", "", 3) - EQ("BCADA", "ABACADA", "replace", "A", "", 2) - EQ("BACADA", "ABACADA", "replace", "A", "", 1) - EQ("ABACADA", "ABACADA", "replace", "A", "", 0) - EQ("BCD", "ABCAD", "replace", "A", "") - EQ("BCD", "ABCADAA", "replace", "A", "") - EQ("BCD", "BCD", "replace", "A", "") - EQ("*************", "*************", "replace", "A", "") - EQ("^A^", "^"+"A"*1000+"^", "replace", "A", "", 999) - - # substring deletion (from=="the", to=="") - EQ("", "the", "replace", "the", "") - EQ("ater", "theater", "replace", "the", "") - EQ("", "thethe", "replace", "the", "") - EQ("", "thethethethe", "replace", "the", "") - EQ("aaaa", "theatheatheathea", "replace", "the", "") - EQ("that", "that", "replace", "the", "") - EQ("thaet", "thaet", "replace", "the", "") - EQ("here and re", "here and there", "replace", "the", "") - EQ("here and re and re", "here and there and there", - "replace", "the", "", sys.maxint) - EQ("here and re and re", "here and there and there", - "replace", "the", "", -1) - EQ("here and re and re", "here and there and there", - "replace", "the", "", 3) - EQ("here and re and re", "here and there and there", - "replace", "the", "", 2) - EQ("here and re and there", "here and there and there", - "replace", "the", "", 1) - EQ("here and there and there", "here and there and there", - "replace", "the", "", 0) - EQ("here and re and re", "here and there and there", "replace", "the", "") - - EQ("abc", "abc", "replace", "the", "") - EQ("abcdefg", "abcdefg", "replace", "the", "") - - # substring deletion (from=="bob", to=="") - EQ("bob", "bbobob", "replace", "bob", "") - EQ("bobXbob", "bbobobXbbobob", "replace", "bob", "") - EQ("aaaaaaa", "aaaaaaabob", "replace", 
"bob", "") - EQ("aaaaaaa", "aaaaaaa", "replace", "bob", "") - - # single character replace in place (len(from)==len(to)==1) - EQ("Who goes there?", "Who goes there?", "replace", "o", "o") - EQ("WhO gOes there?", "Who goes there?", "replace", "o", "O") - EQ("WhO gOes there?", "Who goes there?", "replace", "o", "O", sys.maxint) - EQ("WhO gOes there?", "Who goes there?", "replace", "o", "O", -1) - EQ("WhO gOes there?", "Who goes there?", "replace", "o", "O", 3) - EQ("WhO gOes there?", "Who goes there?", "replace", "o", "O", 2) - EQ("WhO goes there?", "Who goes there?", "replace", "o", "O", 1) - EQ("Who goes there?", "Who goes there?", "replace", "o", "O", 0) - - EQ("Who goes there?", "Who goes there?", "replace", "a", "q") - EQ("who goes there?", "Who goes there?", "replace", "W", "w") - EQ("wwho goes there?ww", "WWho goes there?WW", "replace", "W", "w") - EQ("Who goes there!", "Who goes there?", "replace", "?", "!") - EQ("Who goes there!!", "Who goes there??", "replace", "?", "!") - - EQ("Who goes there?", "Who goes there?", "replace", ".", "!") - - # substring replace in place (len(from)==len(to) > 1) - EQ("Th** ** a t**sue", "This is a tissue", "replace", "is", "**") - EQ("Th** ** a t**sue", "This is a tissue", "replace", "is", "**", sys.maxint) - EQ("Th** ** a t**sue", "This is a tissue", "replace", "is", "**", -1) - EQ("Th** ** a t**sue", "This is a tissue", "replace", "is", "**", 4) - EQ("Th** ** a t**sue", "This is a tissue", "replace", "is", "**", 3) - EQ("Th** ** a tissue", "This is a tissue", "replace", "is", "**", 2) - EQ("Th** is a tissue", "This is a tissue", "replace", "is", "**", 1) - EQ("This is a tissue", "This is a tissue", "replace", "is", "**", 0) - EQ("cobob", "bobob", "replace", "bob", "cob") - EQ("cobobXcobocob", "bobobXbobobob", "replace", "bob", "cob") - EQ("bobob", "bobob", "replace", "bot", "bot") - - # replace single character (len(from)==1, len(to)>1) - EQ("ReyKKjaviKK", "Reykjavik", "replace", "k", "KK") - EQ("ReyKKjaviKK", "Reykjavik", "replace", "k", "KK", -1) - EQ("ReyKKjaviKK", "Reykjavik", "replace", "k", "KK", sys.maxint) - EQ("ReyKKjaviKK", "Reykjavik", "replace", "k", "KK", 2) - EQ("ReyKKjavik", "Reykjavik", "replace", "k", "KK", 1) - EQ("Reykjavik", "Reykjavik", "replace", "k", "KK", 0) - EQ("A----B----C----", "A.B.C.", "replace", ".", "----") - - EQ("Reykjavik", "Reykjavik", "replace", "q", "KK") - - # replace substring (len(from)>1, len(to)!=len(from)) - EQ("ham, ham, eggs and ham", "spam, spam, eggs and spam", - "replace", "spam", "ham") - EQ("ham, ham, eggs and ham", "spam, spam, eggs and spam", - "replace", "spam", "ham", sys.maxint) - EQ("ham, ham, eggs and ham", "spam, spam, eggs and spam", - "replace", "spam", "ham", -1) - EQ("ham, ham, eggs and ham", "spam, spam, eggs and spam", - "replace", "spam", "ham", 4) - EQ("ham, ham, eggs and ham", "spam, spam, eggs and spam", - "replace", "spam", "ham", 3) - EQ("ham, ham, eggs and spam", "spam, spam, eggs and spam", - "replace", "spam", "ham", 2) - EQ("ham, spam, eggs and spam", "spam, spam, eggs and spam", - "replace", "spam", "ham", 1) - EQ("spam, spam, eggs and spam", "spam, spam, eggs and spam", - "replace", "spam", "ham", 0) - - EQ("bobob", "bobobob", "replace", "bobob", "bob") - EQ("bobobXbobob", "bobobobXbobobob", "replace", "bobob", "bob") - EQ("BOBOBOB", "BOBOBOB", "replace", "bob", "bobby") - - with test_support.check_py3k_warnings(): - ba = buffer('a') - bb = buffer('b') - EQ("bbc", "abc", "replace", ba, bb) - EQ("aac", "abc", "replace", bb, ba) - - # - self.checkequal('one@two!three!', 
'one!two!three!', 'replace', '!', '@', 1) - self.checkequal('onetwothree', 'one!two!three!', 'replace', '!', '') - self.checkequal('one@two@three!', 'one!two!three!', 'replace', '!', '@', 2) - self.checkequal('one@two@three@', 'one!two!three!', 'replace', '!', '@', 3) - self.checkequal('one@two@three@', 'one!two!three!', 'replace', '!', '@', 4) - self.checkequal('one!two!three!', 'one!two!three!', 'replace', '!', '@', 0) - self.checkequal('one@two@three@', 'one!two!three!', 'replace', '!', '@') - self.checkequal('one!two!three!', 'one!two!three!', 'replace', 'x', '@') - self.checkequal('one!two!three!', 'one!two!three!', 'replace', 'x', '@', 2) - self.checkequal('-a-b-c-', 'abc', 'replace', '', '-') - self.checkequal('-a-b-c', 'abc', 'replace', '', '-', 3) - self.checkequal('abc', 'abc', 'replace', '', '-', 0) - self.checkequal('', '', 'replace', '', '') - self.checkequal('abc', 'abc', 'replace', 'ab', '--', 0) - self.checkequal('abc', 'abc', 'replace', 'xy', '--') - # Next three for SF bug 422088: [OSF1 alpha] string.replace(); died with - # MemoryError due to empty result (platform malloc issue when requesting - # 0 bytes). - self.checkequal('', '123', 'replace', '123', '') - self.checkequal('', '123123', 'replace', '123', '') - self.checkequal('x', '123x123', 'replace', '123', '') - - self.checkraises(TypeError, 'hello', 'replace') - self.checkraises(TypeError, 'hello', 'replace', 42) - self.checkraises(TypeError, 'hello', 'replace', 42, 'h') - self.checkraises(TypeError, 'hello', 'replace', 'h', 42) +# def test_strip(self): +# # strip/lstrip/rstrip with str arg +# self.checkequal('hello', 'xyzzyhelloxyzzy', 'strip', 'xyz') +# self.checkequal('helloxyzzy', 'xyzzyhelloxyzzy', 'lstrip', 'xyz') +# self.checkequal('xyzzyhello', 'xyzzyhelloxyzzy', 'rstrip', 'xyz') +# self.checkequal('hello', 'hello', 'strip', 'xyz') +# self.checkequal('', 'mississippi', 'strip', 'mississippi') +# +# # only trims the start and end, does not strip internal characters +# self.checkequal('mississipp', 'mississippi', 'strip', 'i') +# +# # strip/lstrip/rstrip with unicode arg +# if self.type2test is not bytearray and test_support.have_unicode: +# self.checkequal(unicode('hello', 'ascii'), 'xyzzyhelloxyzzy', +# 'strip', unicode('xyz', 'ascii')) +# self.checkequal(unicode('helloxyzzy', 'ascii'), 'xyzzyhelloxyzzy', +# 'lstrip', unicode('xyz', 'ascii')) +# self.checkequal(unicode('xyzzyhello', 'ascii'), 'xyzzyhelloxyzzy', +# 'rstrip', unicode('xyz', 'ascii')) +# # XXX +# #self.checkequal(unicode('hello', 'ascii'), 'hello', +# # 'strip', unicode('xyz', 'ascii')) +# +# self.checkraises(TypeError, 'hello', 'strip', 42, 42) +# self.checkraises(TypeError, 'hello', 'lstrip', 42, 42) +# self.checkraises(TypeError, 'hello', 'rstrip', 42, 42) + +# def test_ljust(self): +# self.checkequal('abc ', 'abc', 'ljust', 10) +# self.checkequal('abc ', 'abc', 'ljust', 6) +# self.checkequal('abc', 'abc', 'ljust', 3) +# self.checkequal('abc', 'abc', 'ljust', 2) +# if self.type2test is bytearray: +# # Special case because bytearray argument is not accepted +# self.assertEqual(b'abc*******', bytearray(b'abc').ljust(10, '*')) +# else: +# self.checkequal('abc*******', 'abc', 'ljust', 10, '*') +# self.checkraises(TypeError, 'abc', 'ljust') + +# def test_rjust(self): +# self.checkequal(' abc', 'abc', 'rjust', 10) +# self.checkequal(' abc', 'abc', 'rjust', 6) +# self.checkequal('abc', 'abc', 'rjust', 3) +# self.checkequal('abc', 'abc', 'rjust', 2) +# if self.type2test is bytearray: +# # Special case because bytearray argument is not accepted +# 
self.assertEqual(b'*******abc', bytearray(b'abc').rjust(10, '*')) +# else: +# self.checkequal('*******abc', 'abc', 'rjust', 10, '*') +# self.checkraises(TypeError, 'abc', 'rjust') + +# def test_center(self): +# self.checkequal(' abc ', 'abc', 'center', 10) +# self.checkequal(' abc ', 'abc', 'center', 6) +# self.checkequal('abc', 'abc', 'center', 3) +# self.checkequal('abc', 'abc', 'center', 2) +# if self.type2test is bytearray: +# # Special case because bytearray argument is not accepted +# result = bytearray(b'abc').center(10, '*') +# self.assertEqual(b'***abc****', result) +# else: +# self.checkequal('***abc****', 'abc', 'center', 10, '*') +# self.checkraises(TypeError, 'abc', 'center') + +# def test_swapcase(self): +# self.checkequal('hEllO CoMPuTErS', 'HeLLo cOmpUteRs', 'swapcase') +# +# self.checkraises(TypeError, 'hello', 'swapcase', 42) + +# def test_replace(self): +# EQ = self.checkequal +# +# # Operations on the empty string +# EQ("", "", "replace", "", "") +# EQ("A", "", "replace", "", "A") +# EQ("", "", "replace", "A", "") +# EQ("", "", "replace", "A", "A") +# EQ("", "", "replace", "", "", 100) +# EQ("", "", "replace", "", "", sys.maxint) +# +# # interleave (from=="", 'to' gets inserted everywhere) +# EQ("A", "A", "replace", "", "") +# EQ("*A*", "A", "replace", "", "*") +# EQ("*1A*1", "A", "replace", "", "*1") +# EQ("*-#A*-#", "A", "replace", "", "*-#") +# EQ("*-A*-A*-", "AA", "replace", "", "*-") +# EQ("*-A*-A*-", "AA", "replace", "", "*-", -1) +# EQ("*-A*-A*-", "AA", "replace", "", "*-", sys.maxint) +# EQ("*-A*-A*-", "AA", "replace", "", "*-", 4) +# EQ("*-A*-A*-", "AA", "replace", "", "*-", 3) +# EQ("*-A*-A", "AA", "replace", "", "*-", 2) +# EQ("*-AA", "AA", "replace", "", "*-", 1) +# EQ("AA", "AA", "replace", "", "*-", 0) +# +# # single character deletion (from=="A", to=="") +# EQ("", "A", "replace", "A", "") +# EQ("", "AAA", "replace", "A", "") +# EQ("", "AAA", "replace", "A", "", -1) +# EQ("", "AAA", "replace", "A", "", sys.maxint) +# EQ("", "AAA", "replace", "A", "", 4) +# EQ("", "AAA", "replace", "A", "", 3) +# EQ("A", "AAA", "replace", "A", "", 2) +# EQ("AA", "AAA", "replace", "A", "", 1) +# EQ("AAA", "AAA", "replace", "A", "", 0) +# EQ("", "AAAAAAAAAA", "replace", "A", "") +# EQ("BCD", "ABACADA", "replace", "A", "") +# EQ("BCD", "ABACADA", "replace", "A", "", -1) +# EQ("BCD", "ABACADA", "replace", "A", "", sys.maxint) +# EQ("BCD", "ABACADA", "replace", "A", "", 5) +# EQ("BCD", "ABACADA", "replace", "A", "", 4) +# EQ("BCDA", "ABACADA", "replace", "A", "", 3) +# EQ("BCADA", "ABACADA", "replace", "A", "", 2) +# EQ("BACADA", "ABACADA", "replace", "A", "", 1) +# EQ("ABACADA", "ABACADA", "replace", "A", "", 0) +# EQ("BCD", "ABCAD", "replace", "A", "") +# EQ("BCD", "ABCADAA", "replace", "A", "") +# EQ("BCD", "BCD", "replace", "A", "") +# EQ("*************", "*************", "replace", "A", "") +# EQ("^A^", "^"+"A"*1000+"^", "replace", "A", "", 999) +# +# # substring deletion (from=="the", to=="") +# EQ("", "the", "replace", "the", "") +# EQ("ater", "theater", "replace", "the", "") +# EQ("", "thethe", "replace", "the", "") +# EQ("", "thethethethe", "replace", "the", "") +# EQ("aaaa", "theatheatheathea", "replace", "the", "") +# EQ("that", "that", "replace", "the", "") +# EQ("thaet", "thaet", "replace", "the", "") +# EQ("here and re", "here and there", "replace", "the", "") +# EQ("here and re and re", "here and there and there", +# "replace", "the", "", sys.maxint) +# EQ("here and re and re", "here and there and there", +# "replace", "the", "", -1) +# EQ("here and re and re", 
"here and there and there", +# "replace", "the", "", 3) +# EQ("here and re and re", "here and there and there", +# "replace", "the", "", 2) +# EQ("here and re and there", "here and there and there", +# "replace", "the", "", 1) +# EQ("here and there and there", "here and there and there", +# "replace", "the", "", 0) +# EQ("here and re and re", "here and there and there", "replace", "the", "") +# +# EQ("abc", "abc", "replace", "the", "") +# EQ("abcdefg", "abcdefg", "replace", "the", "") +# +# # substring deletion (from=="bob", to=="") +# EQ("bob", "bbobob", "replace", "bob", "") +# EQ("bobXbob", "bbobobXbbobob", "replace", "bob", "") +# EQ("aaaaaaa", "aaaaaaabob", "replace", "bob", "") +# EQ("aaaaaaa", "aaaaaaa", "replace", "bob", "") +# +# # single character replace in place (len(from)==len(to)==1) +# EQ("Who goes there?", "Who goes there?", "replace", "o", "o") +# EQ("WhO gOes there?", "Who goes there?", "replace", "o", "O") +# EQ("WhO gOes there?", "Who goes there?", "replace", "o", "O", sys.maxint) +# EQ("WhO gOes there?", "Who goes there?", "replace", "o", "O", -1) +# EQ("WhO gOes there?", "Who goes there?", "replace", "o", "O", 3) +# EQ("WhO gOes there?", "Who goes there?", "replace", "o", "O", 2) +# EQ("WhO goes there?", "Who goes there?", "replace", "o", "O", 1) +# EQ("Who goes there?", "Who goes there?", "replace", "o", "O", 0) +# +# EQ("Who goes there?", "Who goes there?", "replace", "a", "q") +# EQ("who goes there?", "Who goes there?", "replace", "W", "w") +# EQ("wwho goes there?ww", "WWho goes there?WW", "replace", "W", "w") +# EQ("Who goes there!", "Who goes there?", "replace", "?", "!") +# EQ("Who goes there!!", "Who goes there??", "replace", "?", "!") +# +# EQ("Who goes there?", "Who goes there?", "replace", ".", "!") +# +# # substring replace in place (len(from)==len(to) > 1) +# EQ("Th** ** a t**sue", "This is a tissue", "replace", "is", "**") +# EQ("Th** ** a t**sue", "This is a tissue", "replace", "is", "**", sys.maxint) +# EQ("Th** ** a t**sue", "This is a tissue", "replace", "is", "**", -1) +# EQ("Th** ** a t**sue", "This is a tissue", "replace", "is", "**", 4) +# EQ("Th** ** a t**sue", "This is a tissue", "replace", "is", "**", 3) +# EQ("Th** ** a tissue", "This is a tissue", "replace", "is", "**", 2) +# EQ("Th** is a tissue", "This is a tissue", "replace", "is", "**", 1) +# EQ("This is a tissue", "This is a tissue", "replace", "is", "**", 0) +# EQ("cobob", "bobob", "replace", "bob", "cob") +# EQ("cobobXcobocob", "bobobXbobobob", "replace", "bob", "cob") +# EQ("bobob", "bobob", "replace", "bot", "bot") +# +# # replace single character (len(from)==1, len(to)>1) +# EQ("ReyKKjaviKK", "Reykjavik", "replace", "k", "KK") +# EQ("ReyKKjaviKK", "Reykjavik", "replace", "k", "KK", -1) +# EQ("ReyKKjaviKK", "Reykjavik", "replace", "k", "KK", sys.maxint) +# EQ("ReyKKjaviKK", "Reykjavik", "replace", "k", "KK", 2) +# EQ("ReyKKjavik", "Reykjavik", "replace", "k", "KK", 1) +# EQ("Reykjavik", "Reykjavik", "replace", "k", "KK", 0) +# EQ("A----B----C----", "A.B.C.", "replace", ".", "----") +# +# EQ("Reykjavik", "Reykjavik", "replace", "q", "KK") +# +# # replace substring (len(from)>1, len(to)!=len(from)) +# EQ("ham, ham, eggs and ham", "spam, spam, eggs and spam", +# "replace", "spam", "ham") +# EQ("ham, ham, eggs and ham", "spam, spam, eggs and spam", +# "replace", "spam", "ham", sys.maxint) +# EQ("ham, ham, eggs and ham", "spam, spam, eggs and spam", +# "replace", "spam", "ham", -1) +# EQ("ham, ham, eggs and ham", "spam, spam, eggs and spam", +# "replace", "spam", "ham", 4) +# EQ("ham, 
ham, eggs and ham", "spam, spam, eggs and spam", +# "replace", "spam", "ham", 3) +# EQ("ham, ham, eggs and spam", "spam, spam, eggs and spam", +# "replace", "spam", "ham", 2) +# EQ("ham, spam, eggs and spam", "spam, spam, eggs and spam", +# "replace", "spam", "ham", 1) +# EQ("spam, spam, eggs and spam", "spam, spam, eggs and spam", +# "replace", "spam", "ham", 0) +# +# EQ("bobob", "bobobob", "replace", "bobob", "bob") +# EQ("bobobXbobob", "bobobobXbobobob", "replace", "bobob", "bob") +# EQ("BOBOBOB", "BOBOBOB", "replace", "bob", "bobby") +# +# with test_support.check_py3k_warnings(): +# ba = buffer('a') +# bb = buffer('b') +# EQ("bbc", "abc", "replace", ba, bb) +# EQ("aac", "abc", "replace", bb, ba) +# +# # +# self.checkequal('one@two!three!', 'one!two!three!', 'replace', '!', '@', 1) +# self.checkequal('onetwothree', 'one!two!three!', 'replace', '!', '') +# self.checkequal('one@two@three!', 'one!two!three!', 'replace', '!', '@', 2) +# self.checkequal('one@two@three@', 'one!two!three!', 'replace', '!', '@', 3) +# self.checkequal('one@two@three@', 'one!two!three!', 'replace', '!', '@', 4) +# self.checkequal('one!two!three!', 'one!two!three!', 'replace', '!', '@', 0) +# self.checkequal('one@two@three@', 'one!two!three!', 'replace', '!', '@') +# self.checkequal('one!two!three!', 'one!two!three!', 'replace', 'x', '@') +# self.checkequal('one!two!three!', 'one!two!three!', 'replace', 'x', '@', 2) +# self.checkequal('-a-b-c-', 'abc', 'replace', '', '-') +# self.checkequal('-a-b-c', 'abc', 'replace', '', '-', 3) +# self.checkequal('abc', 'abc', 'replace', '', '-', 0) +# self.checkequal('', '', 'replace', '', '') +# self.checkequal('abc', 'abc', 'replace', 'ab', '--', 0) +# self.checkequal('abc', 'abc', 'replace', 'xy', '--') +# # Next three for SF bug 422088: [OSF1 alpha] string.replace(); died with +# # MemoryError due to empty result (platform malloc issue when requesting +# # 0 bytes). 
+# self.checkequal('', '123', 'replace', '123', '') +# self.checkequal('', '123123', 'replace', '123', '') +# self.checkequal('x', '123x123', 'replace', '123', '') +# +# self.checkraises(TypeError, 'hello', 'replace') +# self.checkraises(TypeError, 'hello', 'replace', 42) +# self.checkraises(TypeError, 'hello', 'replace', 42, 'h') +# self.checkraises(TypeError, 'hello', 'replace', 'h', 42) @unittest.skipIf(sys.maxint > (1 << 32) or struct.calcsize('P') != 4, 'only applies to 32-bit platforms') @@ -811,21 +811,21 @@ def test_replace_overflow(self): self.checkraises(OverflowError, A2_16, "replace", "A", A2_16) self.checkraises(OverflowError, A2_16, "replace", "AA", A2_16+A2_16) - def test_zfill(self): - self.checkequal('123', '123', 'zfill', 2) - self.checkequal('123', '123', 'zfill', 3) - self.checkequal('0123', '123', 'zfill', 4) - self.checkequal('+123', '+123', 'zfill', 3) - self.checkequal('+123', '+123', 'zfill', 4) - self.checkequal('+0123', '+123', 'zfill', 5) - self.checkequal('-123', '-123', 'zfill', 3) - self.checkequal('-123', '-123', 'zfill', 4) - self.checkequal('-0123', '-123', 'zfill', 5) - self.checkequal('000', '', 'zfill', 3) - self.checkequal('34', '34', 'zfill', 1) - self.checkequal('0034', '34', 'zfill', 4) - - self.checkraises(TypeError, '123', 'zfill') +# def test_zfill(self): +# self.checkequal('123', '123', 'zfill', 2) +# self.checkequal('123', '123', 'zfill', 3) +# self.checkequal('0123', '123', 'zfill', 4) +# self.checkequal('+123', '+123', 'zfill', 3) +# self.checkequal('+123', '+123', 'zfill', 4) +# self.checkequal('+0123', '+123', 'zfill', 5) +# self.checkequal('-123', '-123', 'zfill', 3) +# self.checkequal('-123', '-123', 'zfill', 4) +# self.checkequal('-0123', '-123', 'zfill', 5) +# self.checkequal('000', '', 'zfill', 3) +# self.checkequal('34', '34', 'zfill', 1) +# self.checkequal('0034', '34', 'zfill', 4) +# +# self.checkraises(TypeError, '123', 'zfill') class NonStringModuleTest(object): @@ -1327,19 +1327,19 @@ def test_maketrans(self): ) self.assertRaises(ValueError, string.maketrans, 'abc', 'xyzw') - def test_translate(self): - table = string.maketrans('abc', 'xyz') - self.checkequal('xyzxyz', 'xyzabcdef', 'translate', table, 'def') - - table = string.maketrans('a', 'A') - self.checkequal('Abc', 'abc', 'translate', table) - self.checkequal('xyz', 'xyz', 'translate', table) - self.checkequal('yz', 'xyz', 'translate', table, 'x') - self.checkequal('yx', 'zyzzx', 'translate', None, 'z') - self.checkequal('zyzzx', 'zyzzx', 'translate', None, '') - self.checkequal('zyzzx', 'zyzzx', 'translate', None) - self.checkraises(ValueError, 'xyz', 'translate', 'too short', 'strip') - self.checkraises(ValueError, 'xyz', 'translate', 'too short') +# def test_translate(self): +# table = string.maketrans('abc', 'xyz') +# self.checkequal('xyzxyz', 'xyzabcdef', 'translate', table, 'def') +# +# table = string.maketrans('a', 'A') +# self.checkequal('Abc', 'abc', 'translate', table) +# self.checkequal('xyz', 'xyz', 'translate', table) +# self.checkequal('yz', 'xyz', 'translate', table, 'x') +# self.checkequal('yx', 'zyzzx', 'translate', None, 'z') +# self.checkequal('zyzzx', 'zyzzx', 'translate', None, '') +# self.checkequal('zyzzx', 'zyzzx', 'translate', None) +# self.checkraises(ValueError, 'xyz', 'translate', 'too short', 'strip') +# self.checkraises(ValueError, 'xyz', 'translate', 'too short') class MixinStrUserStringTest(object): diff --git a/third_party/stdlib/test/test_bisect.py b/third_party/stdlib/test/test_bisect.py new file mode 100644 index 00000000..ffc1fa79 --- 
/dev/null +++ b/third_party/stdlib/test/test_bisect.py @@ -0,0 +1,386 @@ +import sys +import unittest +from test import test_support +#from UserList import UserList +import UserList as _UserList +UserList = _UserList.UserList + +# We do a bit of trickery here to be able to test both the C implementation +# and the Python implementation of the module. + +# Make it impossible to import the C implementation anymore. +sys.modules['_bisect'] = 0 +# We must also handle the case that bisect was imported before. +if 'bisect' in sys.modules: + del sys.modules['bisect'] + +# Now we can import the module and get the pure Python implementation. +import bisect as py_bisect + +# Restore everything to normal. +del sys.modules['_bisect'] +del sys.modules['bisect'] + +# This is now the module with the C implementation. +#import bisect as c_bisect + + +class Range(object): + """A trivial xrange()-like object without any integer width limitations.""" + def __init__(self, start, stop): + self.start = start + self.stop = stop + self.last_insert = None + + def __len__(self): + return self.stop - self.start + + def __getitem__(self, idx): + n = self.stop - self.start + if idx < 0: + idx += n + if idx >= n: + raise IndexError(idx) + return self.start + idx + + def insert(self, idx, item): + self.last_insert = idx, item + + +class TestBisect(unittest.TestCase): + # module = None + module = py_bisect + + def setUp(self): + self.precomputedCases = [ + (self.module.bisect_right, [], 1, 0), + (self.module.bisect_right, [1], 0, 0), + (self.module.bisect_right, [1], 1, 1), + (self.module.bisect_right, [1], 2, 1), + (self.module.bisect_right, [1, 1], 0, 0), + (self.module.bisect_right, [1, 1], 1, 2), + (self.module.bisect_right, [1, 1], 2, 2), + (self.module.bisect_right, [1, 1, 1], 0, 0), + (self.module.bisect_right, [1, 1, 1], 1, 3), + (self.module.bisect_right, [1, 1, 1], 2, 3), + (self.module.bisect_right, [1, 1, 1, 1], 0, 0), + (self.module.bisect_right, [1, 1, 1, 1], 1, 4), + (self.module.bisect_right, [1, 1, 1, 1], 2, 4), + (self.module.bisect_right, [1, 2], 0, 0), + (self.module.bisect_right, [1, 2], 1, 1), + (self.module.bisect_right, [1, 2], 1.5, 1), + (self.module.bisect_right, [1, 2], 2, 2), + (self.module.bisect_right, [1, 2], 3, 2), + (self.module.bisect_right, [1, 1, 2, 2], 0, 0), + (self.module.bisect_right, [1, 1, 2, 2], 1, 2), + (self.module.bisect_right, [1, 1, 2, 2], 1.5, 2), + (self.module.bisect_right, [1, 1, 2, 2], 2, 4), + (self.module.bisect_right, [1, 1, 2, 2], 3, 4), + (self.module.bisect_right, [1, 2, 3], 0, 0), + (self.module.bisect_right, [1, 2, 3], 1, 1), + (self.module.bisect_right, [1, 2, 3], 1.5, 1), + (self.module.bisect_right, [1, 2, 3], 2, 2), + (self.module.bisect_right, [1, 2, 3], 2.5, 2), + (self.module.bisect_right, [1, 2, 3], 3, 3), + (self.module.bisect_right, [1, 2, 3], 4, 3), + (self.module.bisect_right, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 0, 0), + (self.module.bisect_right, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 1, 1), + (self.module.bisect_right, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 1.5, 1), + (self.module.bisect_right, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 2, 3), + (self.module.bisect_right, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 2.5, 3), + (self.module.bisect_right, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 3, 6), + (self.module.bisect_right, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 3.5, 6), + (self.module.bisect_right, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 4, 10), + (self.module.bisect_right, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 5, 10), + + (self.module.bisect_left, [], 1, 0), + (self.module.bisect_left, [1], 0, 0), + 
(self.module.bisect_left, [1], 1, 0), + (self.module.bisect_left, [1], 2, 1), + (self.module.bisect_left, [1, 1], 0, 0), + (self.module.bisect_left, [1, 1], 1, 0), + (self.module.bisect_left, [1, 1], 2, 2), + (self.module.bisect_left, [1, 1, 1], 0, 0), + (self.module.bisect_left, [1, 1, 1], 1, 0), + (self.module.bisect_left, [1, 1, 1], 2, 3), + (self.module.bisect_left, [1, 1, 1, 1], 0, 0), + (self.module.bisect_left, [1, 1, 1, 1], 1, 0), + (self.module.bisect_left, [1, 1, 1, 1], 2, 4), + (self.module.bisect_left, [1, 2], 0, 0), + (self.module.bisect_left, [1, 2], 1, 0), + (self.module.bisect_left, [1, 2], 1.5, 1), + (self.module.bisect_left, [1, 2], 2, 1), + (self.module.bisect_left, [1, 2], 3, 2), + (self.module.bisect_left, [1, 1, 2, 2], 0, 0), + (self.module.bisect_left, [1, 1, 2, 2], 1, 0), + (self.module.bisect_left, [1, 1, 2, 2], 1.5, 2), + (self.module.bisect_left, [1, 1, 2, 2], 2, 2), + (self.module.bisect_left, [1, 1, 2, 2], 3, 4), + (self.module.bisect_left, [1, 2, 3], 0, 0), + (self.module.bisect_left, [1, 2, 3], 1, 0), + (self.module.bisect_left, [1, 2, 3], 1.5, 1), + (self.module.bisect_left, [1, 2, 3], 2, 1), + (self.module.bisect_left, [1, 2, 3], 2.5, 2), + (self.module.bisect_left, [1, 2, 3], 3, 2), + (self.module.bisect_left, [1, 2, 3], 4, 3), + (self.module.bisect_left, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 0, 0), + (self.module.bisect_left, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 1, 0), + (self.module.bisect_left, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 1.5, 1), + (self.module.bisect_left, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 2, 1), + (self.module.bisect_left, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 2.5, 3), + (self.module.bisect_left, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 3, 3), + (self.module.bisect_left, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 3.5, 6), + (self.module.bisect_left, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 4, 6), + (self.module.bisect_left, [1, 2, 2, 3, 3, 3, 4, 4, 4, 4], 5, 10) + ] + + def test_precomputed(self): + for func, data, elem, expected in self.precomputedCases: + self.assertEqual(func(data, elem), expected) + self.assertEqual(func(UserList(data), elem), expected) + + def test_negative_lo(self): + # Issue 3301 + mod = self.module + self.assertRaises(ValueError, mod.bisect_left, [1, 2, 3], 5, -1, 3), + self.assertRaises(ValueError, mod.bisect_right, [1, 2, 3], 5, -1, 3), + self.assertRaises(ValueError, mod.insort_left, [1, 2, 3], 5, -1, 3), + self.assertRaises(ValueError, mod.insort_right, [1, 2, 3], 5, -1, 3), + + def test_large_range(self): + # Issue 13496 + mod = self.module + n = sys.maxsize + try: + data = xrange(n-1) + except OverflowError: + self.skipTest("can't create a xrange() object of size `sys.maxsize`") + self.assertEqual(mod.bisect_left(data, n-3), n-3) + self.assertEqual(mod.bisect_right(data, n-3), n-2) + self.assertEqual(mod.bisect_left(data, n-3, n-10, n), n-3) + self.assertEqual(mod.bisect_right(data, n-3, n-10, n), n-2) + + def test_large_pyrange(self): + # Same as above, but without C-imposed limits on range() parameters + mod = self.module + n = sys.maxsize + data = Range(0, n-1) + self.assertEqual(mod.bisect_left(data, n-3), n-3) + self.assertEqual(mod.bisect_right(data, n-3), n-2) + self.assertEqual(mod.bisect_left(data, n-3, n-10, n), n-3) + self.assertEqual(mod.bisect_right(data, n-3, n-10, n), n-2) + x = n - 100 + mod.insort_left(data, x, x - 50, x + 50) + self.assertEqual(data.last_insert, (x, x)) + x = n - 200 + mod.insort_right(data, x, x - 50, x + 50) + self.assertEqual(data.last_insert, (x + 1, x)) + + def test_random(self, n=25): + #from random import randrange + 
import random as _random + randrange = _random.randrange + + for i in xrange(n): + data = [randrange(0, n, 2) for j in xrange(i)] + data.sort() + elem = randrange(-1, n+1) + ip = self.module.bisect_left(data, elem) + if ip < len(data): + self.assertTrue(elem <= data[ip]) + if ip > 0: + self.assertTrue(data[ip-1] < elem) + ip = self.module.bisect_right(data, elem) + if ip < len(data): + self.assertTrue(elem < data[ip]) + if ip > 0: + self.assertTrue(data[ip-1] <= elem) + + def test_optionalSlicing(self): + for func, data, elem, expected in self.precomputedCases: + for lo in xrange(4): + lo = min(len(data), lo) + for hi in xrange(3,8): + hi = min(len(data), hi) + ip = func(data, elem, lo, hi) + self.assertTrue(lo <= ip <= hi) + if func is self.module.bisect_left and ip < hi: + self.assertTrue(elem <= data[ip]) + if func is self.module.bisect_left and ip > lo: + self.assertTrue(data[ip-1] < elem) + if func is self.module.bisect_right and ip < hi: + self.assertTrue(elem < data[ip]) + if func is self.module.bisect_right and ip > lo: + self.assertTrue(data[ip-1] <= elem) + self.assertEqual(ip, max(lo, min(hi, expected))) + + def test_backcompatibility(self): + self.assertEqual(self.module.bisect, self.module.bisect_right) + + def test_keyword_args(self): + data = [10, 20, 30, 40, 50] + self.assertEqual(self.module.bisect_left(a=data, x=25, lo=1, hi=3), 2) + self.assertEqual(self.module.bisect_right(a=data, x=25, lo=1, hi=3), 2) + self.assertEqual(self.module.bisect(a=data, x=25, lo=1, hi=3), 2) + self.module.insort_left(a=data, x=25, lo=1, hi=3) + self.module.insort_right(a=data, x=25, lo=1, hi=3) + self.module.insort(a=data, x=25, lo=1, hi=3) + self.assertEqual(data, [10, 20, 25, 25, 25, 30, 40, 50]) + +# class TestBisectPython(TestBisect): +# module = py_bisect + +# class TestBisectC(TestBisect): +# module = c_bisect + +#============================================================================== + +class TestInsort(unittest.TestCase): + # module = None + module = py_bisect + + def test_vsBuiltinSort(self, n=500): + #from random import choice + import random as _random + choice = _random.choice + + for insorted in (list(), UserList()): + for i in xrange(n): + digit = choice("0123456789") + if digit in "02468": + f = self.module.insort_left + else: + f = self.module.insort_right + f(insorted, digit) + self.assertEqual(sorted(insorted), insorted) + + def test_backcompatibility(self): + self.assertEqual(self.module.insort, self.module.insort_right) + + def test_listDerived(self): + class List(list): + data = [] + def insert(self, index, item): + self.data.insert(index, item) + + lst = List() + self.module.insort_left(lst, 10) + self.module.insort_right(lst, 5) + self.assertEqual([5, 10], lst.data) + +# class TestInsortPython(TestInsort): +# module = py_bisect + +# class TestInsortC(TestInsort): +# module = c_bisect + +#============================================================================== + + +class LenOnly(object): + "Dummy sequence class defining __len__ but not __getitem__." + def __len__(self): + return 10 + + # Have to define LenOnly as an object for the Grumpy runtime. Doing so will + # raise a TypeError instead of an AttributeError when accessing __getitem__, + # so we redefine __getitem__ to raise an AttributeError. + def __getitem__(self, ndx): + raise AttributeError + +class GetOnly(object): + "Dummy sequence class defining __getitem__ but not __len__." 
+ def __getitem__(self, ndx): + return 10 + + def __len__(self): + raise AttributeError + +class CmpErr(object): + "Dummy element that always raises an error during comparison" + def __cmp__(self, other): + raise ZeroDivisionError + +class TestErrorHandling(unittest.TestCase): + # module = None + module = py_bisect + + def test_non_sequence(self): + for f in (self.module.bisect_left, self.module.bisect_right, + self.module.insort_left, self.module.insort_right): + self.assertRaises(TypeError, f, 10, 10) + + def test_len_only(self): + for f in (self.module.bisect_left, self.module.bisect_right, + self.module.insort_left, self.module.insort_right): + self.assertRaises(AttributeError, f, LenOnly(), 10) + + def test_get_only(self): + for f in (self.module.bisect_left, self.module.bisect_right, + self.module.insort_left, self.module.insort_right): + self.assertRaises(AttributeError, f, GetOnly(), 10) + + def test_cmp_err(self): + seq = [CmpErr(), CmpErr(), CmpErr()] + for f in (self.module.bisect_left, self.module.bisect_right, + self.module.insort_left, self.module.insort_right): + self.assertRaises(ZeroDivisionError, f, seq, 10) + + def test_arg_parsing(self): + for f in (self.module.bisect_left, self.module.bisect_right, + self.module.insort_left, self.module.insort_right): + self.assertRaises(TypeError, f, 10) + +# class TestErrorHandlingPython(TestErrorHandling): +# module = py_bisect + +# class TestErrorHandlingC(TestErrorHandling): +# module = c_bisect + +#============================================================================== + +libreftest = """ +Example from the Library Reference: Doc/library/bisect.rst +The bisect() function is generally useful for categorizing numeric data. +This example uses bisect() to look up a letter grade for an exam total +(say) based on a set of ordered numeric breakpoints: 85 and up is an `A', +75..84 is a `B', etc. + >>> grades = "FEDCBA" + >>> breakpoints = [30, 44, 66, 75, 85] + >>> from bisect import bisect + >>> def grade(total): + ... return grades[bisect(breakpoints, total)] + ... 
+ >>> grade(66) + 'C' + >>> map(grade, [33, 99, 77, 44, 12, 88]) + ['E', 'A', 'B', 'D', 'F', 'A'] +""" + +#------------------------------------------------------------------------------ + +__test__ = {'libreftest' : libreftest} + +def test_main(verbose=None): + # from test import test_bisect + + # test_classes = [TestBisectPython, TestBisectC, + # TestInsortPython, TestInsortC, + # TestErrorHandlingPython, TestErrorHandlingC] + test_classes = [TestBisect, TestInsort, TestErrorHandling] + + test_support.run_unittest(*test_classes) + # test_support.run_doctest(test_bisect, verbose) + + # verify reference counting + if verbose and hasattr(sys, "gettotalrefcount"): + #import gc + counts = [None] * 5 + for i in xrange(len(counts)): + test_support.run_unittest(*test_classes) + #gc.collect() + counts[i] = sys.gettotalrefcount() + print counts + +if __name__ == "__main__": + test_main(verbose=True) diff --git a/third_party/stdlib/test/test_colorsys.py b/third_party/stdlib/test/test_colorsys.py new file mode 100644 index 00000000..5c8860ff --- /dev/null +++ b/third_party/stdlib/test/test_colorsys.py @@ -0,0 +1,104 @@ +import unittest +import colorsys +from test import test_support + +def frange(start, stop, step): + while start <= stop: + yield start + start += step + +class ColorsysTest(unittest.TestCase): + + def assertTripleEqual(self, tr1, tr2): + self.assertEqual(len(tr1), 3) + self.assertEqual(len(tr2), 3) + self.assertAlmostEqual(tr1[0], tr2[0]) + self.assertAlmostEqual(tr1[1], tr2[1]) + self.assertAlmostEqual(tr1[2], tr2[2]) + + def test_hsv_roundtrip(self): + for r in frange(0.0, 1.0, 0.2): + for g in frange(0.0, 1.0, 0.2): + for b in frange(0.0, 1.0, 0.2): + rgb = (r, g, b) + self.assertTripleEqual( + rgb, + colorsys.hsv_to_rgb(*colorsys.rgb_to_hsv(*rgb)) + ) + + def test_hsv_values(self): + values = [ + # rgb, hsv + ((0.0, 0.0, 0.0), ( 0 , 0.0, 0.0)), # black + ((0.0, 0.0, 1.0), (4./6., 1.0, 1.0)), # blue + ((0.0, 1.0, 0.0), (2./6., 1.0, 1.0)), # green + ((0.0, 1.0, 1.0), (3./6., 1.0, 1.0)), # cyan + ((1.0, 0.0, 0.0), ( 0 , 1.0, 1.0)), # red + ((1.0, 0.0, 1.0), (5./6., 1.0, 1.0)), # purple + ((1.0, 1.0, 0.0), (1./6., 1.0, 1.0)), # yellow + ((1.0, 1.0, 1.0), ( 0 , 0.0, 1.0)), # white + ((0.5, 0.5, 0.5), ( 0 , 0.0, 0.5)), # grey + ] + for (rgb, hsv) in values: + self.assertTripleEqual(hsv, colorsys.rgb_to_hsv(*rgb)) + self.assertTripleEqual(rgb, colorsys.hsv_to_rgb(*hsv)) + + def test_hls_roundtrip(self): + for r in frange(0.0, 1.0, 0.2): + for g in frange(0.0, 1.0, 0.2): + for b in frange(0.0, 1.0, 0.2): + rgb = (r, g, b) + self.assertTripleEqual( + rgb, + colorsys.hls_to_rgb(*colorsys.rgb_to_hls(*rgb)) + ) + + def test_hls_values(self): + values = [ + # rgb, hls + ((0.0, 0.0, 0.0), ( 0 , 0.0, 0.0)), # black + ((0.0, 0.0, 1.0), (4./6., 0.5, 1.0)), # blue + ((0.0, 1.0, 0.0), (2./6., 0.5, 1.0)), # green + ((0.0, 1.0, 1.0), (3./6., 0.5, 1.0)), # cyan + ((1.0, 0.0, 0.0), ( 0 , 0.5, 1.0)), # red + ((1.0, 0.0, 1.0), (5./6., 0.5, 1.0)), # purple + ((1.0, 1.0, 0.0), (1./6., 0.5, 1.0)), # yellow + ((1.0, 1.0, 1.0), ( 0 , 1.0, 0.0)), # white + ((0.5, 0.5, 0.5), ( 0 , 0.5, 0.0)), # grey + ] + for (rgb, hls) in values: + self.assertTripleEqual(hls, colorsys.rgb_to_hls(*rgb)) + self.assertTripleEqual(rgb, colorsys.hls_to_rgb(*hls)) + + def test_yiq_roundtrip(self): + for r in frange(0.0, 1.0, 0.2): + for g in frange(0.0, 1.0, 0.2): + for b in frange(0.0, 1.0, 0.2): + rgb = (r, g, b) + self.assertTripleEqual( + rgb, + colorsys.yiq_to_rgb(*colorsys.rgb_to_yiq(*rgb)) + ) + + def test_yiq_values(self): + 
values = [ + # rgb, yiq + ((0.0, 0.0, 0.0), (0.0, 0.0, 0.0)), # black + ((0.0, 0.0, 1.0), (0.11, -0.3217, 0.3121)), # blue + ((0.0, 1.0, 0.0), (0.59, -0.2773, -0.5251)), # green + ((0.0, 1.0, 1.0), (0.7, -0.599, -0.213)), # cyan + ((1.0, 0.0, 0.0), (0.3, 0.599, 0.213)), # red + ((1.0, 0.0, 1.0), (0.41, 0.2773, 0.5251)), # purple + ((1.0, 1.0, 0.0), (0.89, 0.3217, -0.3121)), # yellow + ((1.0, 1.0, 1.0), (1.0, 0.0, 0.0)), # white + ((0.5, 0.5, 0.5), (0.5, 0.0, 0.0)), # grey + ] + for (rgb, yiq) in values: + self.assertTripleEqual(yiq, colorsys.rgb_to_yiq(*rgb)) + self.assertTripleEqual(rgb, colorsys.yiq_to_rgb(*yiq)) + +def test_main(): + test_support.run_unittest(ColorsysTest) + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_datetime.py b/third_party/stdlib/test/test_datetime.py new file mode 100644 index 00000000..b7bf2afd --- /dev/null +++ b/third_party/stdlib/test/test_datetime.py @@ -0,0 +1,3420 @@ +"""Test date/time type. + +See http://www.zope.org/Members/fdrake/DateTimeWiki/TestCases +""" +# from __future__ import division +import sys +# import pickle +# import cPickle +import unittest + +from test import test_support + +# from datetime import MINYEAR, MAXYEAR +# from datetime import timedelta +# from datetime import tzinfo +# from datetime import time +# from datetime import date, datetime +import datetime +MINYEAR, MAXYEAR, timedelta, tzinfo, time, date, datetime = \ + datetime.MINYEAR, datetime.MAXYEAR, datetime.timedelta, datetime.tzinfo, \ + datetime.time, datetime.date, datetime.datetime + +# pickle_choices = [(pickler, unpickler, proto) +# for pickler in pickle, cPickle +# for unpickler in pickle, cPickle +# for proto in range(3)] +# assert len(pickle_choices) == 2*2*3 + +# An arbitrary collection of objects of non-datetime types, for testing +# mixed-type comparisons. +OTHERSTUFF = (10, 10L, 34.5, "abc", {}, [], ()) + + +############################################################################# +# module tests + +class TestModule(unittest.TestCase): + + def test_constants(self): + import datetime + self.assertEqual(datetime.MINYEAR, 1) + self.assertEqual(datetime.MAXYEAR, 9999) + +############################################################################# +# tzinfo tests + +class FixedOffset(tzinfo): + def __init__(self, offset, name, dstoffset=42): + if isinstance(offset, int): + offset = timedelta(minutes=offset) + if isinstance(dstoffset, int): + dstoffset = timedelta(minutes=dstoffset) + self.__offset = offset + self.__name = name + self.__dstoffset = dstoffset + def __repr__(self): + return self.__name.lower() + def utcoffset(self, dt): + return self.__offset + def tzname(self, dt): + return self.__name + def dst(self, dt): + return self.__dstoffset + +class PicklableFixedOffset(FixedOffset): + def __init__(self, offset=None, name=None, dstoffset=None): + FixedOffset.__init__(self, offset, name, dstoffset) + +class TestTZInfo(unittest.TestCase): + + def test_non_abstractness(self): + # In order to allow subclasses to get pickled, the C implementation + # wasn't able to get away with having __init__ raise + # NotImplementedError. 
+ useless = tzinfo() + dt = datetime.max + self.assertRaises(NotImplementedError, useless.tzname, dt) + self.assertRaises(NotImplementedError, useless.utcoffset, dt) + self.assertRaises(NotImplementedError, useless.dst, dt) + + def test_subclass_must_override(self): + class NotEnough(tzinfo): + def __init__(self, offset, name): + self.__offset = offset + self.__name = name + self.assertTrue(issubclass(NotEnough, tzinfo)) + ne = NotEnough(3, "NotByALongShot") + self.assertIsInstance(ne, tzinfo) + + dt = datetime.now() + self.assertRaises(NotImplementedError, ne.tzname, dt) + self.assertRaises(NotImplementedError, ne.utcoffset, dt) + self.assertRaises(NotImplementedError, ne.dst, dt) + + def test_normal(self): + fo = FixedOffset(3, "Three") + self.assertIsInstance(fo, tzinfo) + for dt in datetime.now(), None: + self.assertEqual(fo.utcoffset(dt), timedelta(minutes=3)) + self.assertEqual(fo.tzname(dt), "Three") + self.assertEqual(fo.dst(dt), timedelta(minutes=42)) + + # def test_pickling_base(self): + # # There's no point to pickling tzinfo objects on their own (they + # # carry no data), but they need to be picklable anyway else + # # concrete subclasses can't be pickled. + # orig = tzinfo.__new__(tzinfo) + # self.assertIs(type(orig), tzinfo) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertIs(type(derived), tzinfo) + + # def test_pickling_subclass(self): + # # Make sure we can pickle/unpickle an instance of a subclass. + # offset = timedelta(minutes=-300) + # orig = PicklableFixedOffset(offset, 'cookie') + # self.assertIsInstance(orig, tzinfo) + # self.assertTrue(type(orig) is PicklableFixedOffset) + # self.assertEqual(orig.utcoffset(None), offset) + # self.assertEqual(orig.tzname(None), 'cookie') + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertIsInstance(derived, tzinfo) + # self.assertTrue(type(derived) is PicklableFixedOffset) + # self.assertEqual(derived.utcoffset(None), offset) + # self.assertEqual(derived.tzname(None), 'cookie') + +############################################################################# +# Base class for testing a particular aspect of timedelta, time, date and +# datetime comparisons. + +class HarmlessMixedComparison(object): + # Test that __eq__ and __ne__ don't complain for mixed-type comparisons. + + # Subclasses must define 'theclass', and theclass(1, 1, 1) must be a + # legit constructor. 
+ + def test_harmless_mixed_comparison(self): + me = self.theclass(1, 1, 1) + + self.assertFalse(me == ()) + self.assertTrue(me != ()) + self.assertFalse(() == me) + self.assertTrue(() != me) + + self.assertIn(me, [1, 20L, [], me]) + self.assertIn([], [me, 1, 20L, []]) + + def test_harmful_mixed_comparison(self): + me = self.theclass(1, 1, 1) + + self.assertRaises(TypeError, lambda: me < ()) + self.assertRaises(TypeError, lambda: me <= ()) + self.assertRaises(TypeError, lambda: me > ()) + self.assertRaises(TypeError, lambda: me >= ()) + + self.assertRaises(TypeError, lambda: () < me) + self.assertRaises(TypeError, lambda: () <= me) + self.assertRaises(TypeError, lambda: () > me) + self.assertRaises(TypeError, lambda: () >= me) + + self.assertRaises(TypeError, cmp, (), me) + self.assertRaises(TypeError, cmp, me, ()) + +############################################################################# +# timedelta tests + +class TestTimeDelta(HarmlessMixedComparison, unittest.TestCase): + + theclass = timedelta + + def test_constructor(self): + eq = self.assertEqual + td = timedelta + + # Check keyword args to constructor + eq(td(), td(weeks=0, days=0, hours=0, minutes=0, seconds=0, + milliseconds=0, microseconds=0)) + eq(td(1), td(days=1)) + eq(td(0, 1), td(seconds=1)) + eq(td(0, 0, 1), td(microseconds=1)) + eq(td(weeks=1), td(days=7)) + eq(td(days=1), td(hours=24)) + eq(td(hours=1), td(minutes=60)) + eq(td(minutes=1), td(seconds=60)) + eq(td(seconds=1), td(milliseconds=1000)) + eq(td(milliseconds=1), td(microseconds=1000)) + + # Check float args to constructor + eq(td(weeks=1.0/7), td(days=1)) + eq(td(days=1.0/24), td(hours=1)) + eq(td(hours=1.0/60), td(minutes=1)) + eq(td(minutes=1.0/60), td(seconds=1)) + eq(td(seconds=0.001), td(milliseconds=1)) + eq(td(milliseconds=0.001), td(microseconds=1)) + + def test_computations(self): + eq = self.assertEqual + td = timedelta + + a = td(7) # One week + b = td(0, 60) # One minute + c = td(0, 0, 1000) # One millisecond + eq(a+b+c, td(7, 60, 1000)) + eq(a-b, td(6, 24*3600 - 60)) + eq(-a, td(-7)) + # eq(+a, td(7)) + eq(-b, td(-1, 24*3600 - 60)) + eq(-c, td(-1, 24*3600 - 1, 999000)) + eq(abs(a), a) + eq(abs(-a), a) + eq(td(6, 24*3600), a) + eq(td(0, 0, 60*1000000), b) + eq(a*10, td(70)) + eq(a*10, 10*a) + eq(a*10L, 10*a) + eq(b*10, td(0, 600)) + eq(10*b, td(0, 600)) + eq(b*10L, td(0, 600)) + eq(c*10, td(0, 0, 10000)) + eq(10*c, td(0, 0, 10000)) + eq(c*10L, td(0, 0, 10000)) + eq(a*-1, -a) + eq(b*-2, -b-b) + eq(c*-2, -c+-c) + eq(b*(60*24), (b*60)*24) + eq(b*(60*24), (60*b)*24) + eq(c*1000, td(0, 1)) + eq(1000*c, td(0, 1)) + eq(a//7, td(1)) + eq(b//10, td(0, 6)) + eq(c//1000, td(0, 0, 1)) + eq(a//10, td(0, 7*24*360)) + eq(a//3600000, td(0, 0, 7*24*1000)) + + # Issue #11576 + eq(td(999999999, 86399, 999999) - td(999999999, 86399, 999998), + td(0, 0, 1)) + eq(td(999999999, 1, 1) - td(999999999, 1, 0), + td(0, 0, 1)) + + + def test_disallowed_computations(self): + a = timedelta(42) + + # Add/sub ints, longs, floats should be illegal + for i in 1, 1L, 1.0: + self.assertRaises(TypeError, lambda: a+i) + self.assertRaises(TypeError, lambda: a-i) + self.assertRaises(TypeError, lambda: i+a) + self.assertRaises(TypeError, lambda: i-a) + + # Mul/div by float isn't supported. 
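+        # (Python 2's timedelta supports multiplication and division only by
+        # integers and longs, so each of the float operations below is
+        # expected to raise TypeError.)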
+ x = 2.3 + self.assertRaises(TypeError, lambda: a*x) + self.assertRaises(TypeError, lambda: x*a) + self.assertRaises(TypeError, lambda: a/x) + self.assertRaises(TypeError, lambda: x/a) + self.assertRaises(TypeError, lambda: a // x) + self.assertRaises(TypeError, lambda: x // a) + + # Division of int by timedelta doesn't make sense. + # Division by zero doesn't make sense. + for zero in 0, 0L: + self.assertRaises(TypeError, lambda: zero // a) + self.assertRaises(ZeroDivisionError, lambda: a // zero) + + def test_basic_attributes(self): + days, seconds, us = 1, 7, 31 + td = timedelta(days, seconds, us) + self.assertEqual(td.days, days) + self.assertEqual(td.seconds, seconds) + self.assertEqual(td.microseconds, us) + + @unittest.expectedFailure + def test_total_seconds(self): + td = timedelta(days=365) + self.assertEqual(td.total_seconds(), 31536000.0) + for total_seconds in [123456.789012, -123456.789012, 0.123456, 0, 1e6]: + td = timedelta(seconds=total_seconds) + self.assertEqual(td.total_seconds(), total_seconds) + # Issue8644: Test that td.total_seconds() has the same + # accuracy as td / timedelta(seconds=1). + for ms in [-1, -2, -123]: + td = timedelta(microseconds=ms) + self.assertEqual(td.total_seconds(), + ((24*3600*td.days + td.seconds)*10**6 + + td.microseconds)/10**6) + + def test_carries(self): + t1 = timedelta(days=100, + weeks=-7, + hours=-24*(100-49), + minutes=-3, + seconds=12, + microseconds=(3*60 - 12) * 1e6 + 1) + t2 = timedelta(microseconds=1) + self.assertEqual(t1, t2) + + @unittest.expectedFailure + def test_hash_equality(self): + t1 = timedelta(days=100, + weeks=-7, + hours=-24*(100-49), + minutes=-3, + seconds=12, + microseconds=(3*60 - 12) * 1000000) + t2 = timedelta() + self.assertEqual(hash(t1), hash(t2)) + + t1 += timedelta(weeks=7) + t2 += timedelta(days=7*7) + self.assertEqual(t1, t2) + self.assertEqual(hash(t1), hash(t2)) + + d = {t1: 1} + d[t2] = 2 + self.assertEqual(len(d), 1) + self.assertEqual(d[t1], 2) + + # def test_pickling(self): + # args = 12, 34, 56 + # orig = timedelta(*args) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + + def test_compare(self): + t1 = timedelta(2, 3, 4) + t2 = timedelta(2, 3, 4) + self.assertTrue(t1 == t2) + self.assertTrue(t1 <= t2) + self.assertTrue(t1 >= t2) + self.assertFalse(t1 != t2) + self.assertFalse(t1 < t2) + self.assertFalse(t1 > t2) + self.assertEqual(cmp(t1, t2), 0) + self.assertEqual(cmp(t2, t1), 0) + + for args in (3, 3, 3), (2, 4, 4), (2, 3, 5): + t2 = timedelta(*args) # this is larger than t1 + self.assertTrue(t1 < t2) + self.assertTrue(t2 > t1) + self.assertTrue(t1 <= t2) + self.assertTrue(t2 >= t1) + self.assertTrue(t1 != t2) + self.assertTrue(t2 != t1) + self.assertFalse(t1 == t2) + self.assertFalse(t2 == t1) + self.assertFalse(t1 > t2) + self.assertFalse(t2 < t1) + self.assertFalse(t1 >= t2) + self.assertFalse(t2 <= t1) + self.assertEqual(cmp(t1, t2), -1) + self.assertEqual(cmp(t2, t1), 1) + + for badarg in OTHERSTUFF: + self.assertEqual(t1 == badarg, False) + self.assertEqual(t1 != badarg, True) + self.assertEqual(badarg == t1, False) + self.assertEqual(badarg != t1, True) + + self.assertRaises(TypeError, lambda: t1 <= badarg) + self.assertRaises(TypeError, lambda: t1 < badarg) + self.assertRaises(TypeError, lambda: t1 > badarg) + self.assertRaises(TypeError, lambda: t1 >= badarg) + self.assertRaises(TypeError, lambda: badarg <= t1) + self.assertRaises(TypeError, lambda: badarg < t1) + 
self.assertRaises(TypeError, lambda: badarg > t1) + self.assertRaises(TypeError, lambda: badarg >= t1) + + def test_str(self): + td = timedelta + eq = self.assertEqual + + eq(str(td(1)), "1 day, 0:00:00") + eq(str(td(-1)), "-1 day, 0:00:00") + eq(str(td(2)), "2 days, 0:00:00") + eq(str(td(-2)), "-2 days, 0:00:00") + + eq(str(td(hours=12, minutes=58, seconds=59)), "12:58:59") + eq(str(td(hours=2, minutes=3, seconds=4)), "2:03:04") + eq(str(td(weeks=-30, hours=23, minutes=12, seconds=34)), + "-210 days, 23:12:34") + + eq(str(td(milliseconds=1)), "0:00:00.001000") + eq(str(td(microseconds=3)), "0:00:00.000003") + + eq(str(td(days=999999999, hours=23, minutes=59, seconds=59, + microseconds=999999)), + "999999999 days, 23:59:59.999999") + + @unittest.expectedFailure + def test_roundtrip(self): + for td in (timedelta(days=999999999, hours=23, minutes=59, + seconds=59, microseconds=999999), + timedelta(days=-999999999), + timedelta(days=1, seconds=2, microseconds=3)): + + # Verify td -> string -> td identity. + s = repr(td) + self.assertTrue(s.startswith('datetime.')) + s = s[9:] + td2 = eval(s) + self.assertEqual(td, td2) + + # Verify identity via reconstructing from pieces. + td2 = timedelta(td.days, td.seconds, td.microseconds) + self.assertEqual(td, td2) + + def test_resolution_info(self): + self.assertIsInstance(timedelta.min, timedelta) + self.assertIsInstance(timedelta.max, timedelta) + self.assertIsInstance(timedelta.resolution, timedelta) + self.assertTrue(timedelta.max > timedelta.min) + self.assertEqual(timedelta.min, timedelta(-999999999)) + self.assertEqual(timedelta.max, timedelta(999999999, 24*3600-1, 1e6-1)) + self.assertEqual(timedelta.resolution, timedelta(0, 0, 1)) + + def test_overflow(self): + tiny = timedelta.resolution + + td = timedelta.min + tiny + td -= tiny # no problem + self.assertRaises(OverflowError, td.__sub__, tiny) + self.assertRaises(OverflowError, td.__add__, -tiny) + + td = timedelta.max - tiny + td += tiny # no problem + self.assertRaises(OverflowError, td.__add__, tiny) + self.assertRaises(OverflowError, td.__sub__, -tiny) + + self.assertRaises(OverflowError, lambda: -timedelta.max) + + def test_microsecond_rounding(self): + td = timedelta + eq = self.assertEqual + + # Single-field rounding. + eq(td(milliseconds=0.4/1000), td(0)) # rounds to 0 + eq(td(milliseconds=-0.4/1000), td(0)) # rounds to 0 + eq(td(milliseconds=0.6/1000), td(microseconds=1)) + eq(td(milliseconds=-0.6/1000), td(microseconds=-1)) + + # Rounding due to contributions from more than one field. 
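+        # (Each field alone stays under half a microsecond -- days=.4/us_per_day
+        # is ~0.4us and hours=.2/us_per_hour is ~0.2us -- but their combined
+        # ~0.6us rounds up to td(microseconds=1) below.)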
+ us_per_hour = 3600e6 + us_per_day = us_per_hour * 24 + eq(td(days=.4/us_per_day), td(0)) + eq(td(hours=.2/us_per_hour), td(0)) + eq(td(days=.4/us_per_day, hours=.2/us_per_hour), td(microseconds=1)) + + eq(td(days=-.4/us_per_day), td(0)) + eq(td(hours=-.2/us_per_hour), td(0)) + eq(td(days=-.4/us_per_day, hours=-.2/us_per_hour), td(microseconds=-1)) + + def test_massive_normalization(self): + td = timedelta(microseconds=-1) + self.assertEqual((td.days, td.seconds, td.microseconds), + (-1, 24*3600-1, 999999)) + + def test_bool(self): + self.assertTrue(timedelta(1)) + self.assertTrue(timedelta(0, 1)) + self.assertTrue(timedelta(0, 0, 1)) + self.assertTrue(timedelta(microseconds=1)) + self.assertFalse(timedelta(0)) + + def test_subclass_timedelta(self): + + class T(timedelta): + @staticmethod + def from_td(td): + return T(td.days, td.seconds, td.microseconds) + + def as_hours(self): + sum = (self.days * 24 + + self.seconds / 3600.0 + + self.microseconds / 3600e6) + return round(sum) + + t1 = T(days=1) + self.assertIs(type(t1), T) + self.assertEqual(t1.as_hours(), 24) + + t2 = T(days=-1, seconds=-3600) + self.assertIs(type(t2), T) + self.assertEqual(t2.as_hours(), -25) + + t3 = t1 + t2 + self.assertIs(type(t3), timedelta) + t4 = T.from_td(t3) + self.assertIs(type(t4), T) + self.assertEqual(t3.days, t4.days) + self.assertEqual(t3.seconds, t4.seconds) + self.assertEqual(t3.microseconds, t4.microseconds) + self.assertEqual(str(t3), str(t4)) + self.assertEqual(t4.as_hours(), -1) + +############################################################################# +# date tests + +class TestDateOnly(unittest.TestCase): + # Tests here won't pass if also run on datetime objects, so don't + # subclass this to test datetimes too. + + def test_delta_non_days_ignored(self): + dt = date(2000, 1, 2) + delta = timedelta(days=1, hours=2, minutes=3, seconds=4, + microseconds=5) + days = timedelta(delta.days) + self.assertEqual(days, timedelta(1)) + + dt2 = dt + delta + self.assertEqual(dt2, dt + days) + + dt2 = delta + dt + self.assertEqual(dt2, dt + days) + + dt2 = dt - delta + self.assertEqual(dt2, dt - days) + + delta = -delta + days = timedelta(delta.days) + self.assertEqual(days, timedelta(-2)) + + dt2 = dt + delta + self.assertEqual(dt2, dt + days) + + dt2 = delta + dt + self.assertEqual(dt2, dt + days) + + dt2 = dt - delta + self.assertEqual(dt2, dt - days) + +class SubclassDate(date): + sub_var = 1 + +class TestDate(HarmlessMixedComparison, unittest.TestCase): + # Tests here should pass for both dates and datetimes, except for a + # few tests that TestDateTime overrides. + + theclass = date + + def test_basic_attributes(self): + dt = self.theclass(2002, 3, 1) + self.assertEqual(dt.year, 2002) + self.assertEqual(dt.month, 3) + self.assertEqual(dt.day, 1) + + @unittest.expectedFailure + def test_roundtrip(self): + for dt in (self.theclass(1, 2, 3), + self.theclass.today()): + # Verify dt -> string -> date identity. + s = repr(dt) + self.assertTrue(s.startswith('datetime.')) + s = s[9:] + dt2 = eval(s) + self.assertEqual(dt, dt2) + + # Verify identity via reconstructing from pieces. + dt2 = self.theclass(dt.year, dt.month, dt.day) + self.assertEqual(dt, dt2) + + def test_ordinal_conversions(self): + # Check some fixed values. 
+ for y, m, d, n in [(1, 1, 1, 1), # calendar origin + (1, 12, 31, 365), + (2, 1, 1, 366), + # first example from "Calendrical Calculations" + (1945, 11, 12, 710347)]: + d = self.theclass(y, m, d) + self.assertEqual(n, d.toordinal()) + fromord = self.theclass.fromordinal(n) + self.assertEqual(d, fromord) + if hasattr(fromord, "hour"): + # if we're checking something fancier than a date, verify + # the extra fields have been zeroed out + self.assertEqual(fromord.hour, 0) + self.assertEqual(fromord.minute, 0) + self.assertEqual(fromord.second, 0) + self.assertEqual(fromord.microsecond, 0) + + # Check first and last days of year spottily across the whole + # range of years supported. + for year in xrange(MINYEAR, MAXYEAR+1, 7): + # Verify (year, 1, 1) -> ordinal -> y, m, d is identity. + d = self.theclass(year, 1, 1) + n = d.toordinal() + d2 = self.theclass.fromordinal(n) + self.assertEqual(d, d2) + # Verify that moving back a day gets to the end of year-1. + if year > 1: + d = self.theclass.fromordinal(n-1) + d2 = self.theclass(year-1, 12, 31) + self.assertEqual(d, d2) + self.assertEqual(d2.toordinal(), n-1) + + # Test every day in a leap-year and a non-leap year. + dim = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + for year, isleap in (2000, True), (2002, False): + n = self.theclass(year, 1, 1).toordinal() + for month, maxday in zip(range(1, 13), dim): + if month == 2 and isleap: + maxday += 1 + for day in range(1, maxday+1): + d = self.theclass(year, month, day) + self.assertEqual(d.toordinal(), n) + self.assertEqual(d, self.theclass.fromordinal(n)) + n += 1 + + def test_extreme_ordinals(self): + a = self.theclass.min + a = self.theclass(a.year, a.month, a.day) # get rid of time parts + aord = a.toordinal() + b = a.fromordinal(aord) + self.assertEqual(a, b) + + self.assertRaises(ValueError, lambda: a.fromordinal(aord - 1)) + + b = a + timedelta(days=1) + self.assertEqual(b.toordinal(), aord + 1) + self.assertEqual(b, self.theclass.fromordinal(aord + 1)) + + a = self.theclass.max + a = self.theclass(a.year, a.month, a.day) # get rid of time parts + aord = a.toordinal() + b = a.fromordinal(aord) + self.assertEqual(a, b) + + self.assertRaises(ValueError, lambda: a.fromordinal(aord + 1)) + + b = a - timedelta(days=1) + self.assertEqual(b.toordinal(), aord - 1) + self.assertEqual(b, self.theclass.fromordinal(aord - 1)) + + def test_bad_constructor_arguments(self): + # bad years + self.theclass(MINYEAR, 1, 1) # no exception + self.theclass(MAXYEAR, 1, 1) # no exception + self.assertRaises(ValueError, self.theclass, MINYEAR-1, 1, 1) + self.assertRaises(ValueError, self.theclass, MAXYEAR+1, 1, 1) + # bad months + self.theclass(2000, 1, 1) # no exception + self.theclass(2000, 12, 1) # no exception + self.assertRaises(ValueError, self.theclass, 2000, 0, 1) + self.assertRaises(ValueError, self.theclass, 2000, 13, 1) + # bad days + self.theclass(2000, 2, 29) # no exception + self.theclass(2004, 2, 29) # no exception + self.theclass(2400, 2, 29) # no exception + self.assertRaises(ValueError, self.theclass, 2000, 2, 30) + self.assertRaises(ValueError, self.theclass, 2001, 2, 29) + self.assertRaises(ValueError, self.theclass, 2100, 2, 29) + self.assertRaises(ValueError, self.theclass, 1900, 2, 29) + self.assertRaises(ValueError, self.theclass, 2000, 1, 0) + self.assertRaises(ValueError, self.theclass, 2000, 1, 32) + + @unittest.expectedFailure + def test_hash_equality(self): + d = self.theclass(2000, 12, 31) + # same thing + e = self.theclass(2000, 12, 31) + self.assertEqual(d, e) + 
self.assertEqual(hash(d), hash(e)) + + dic = {d: 1} + dic[e] = 2 + self.assertEqual(len(dic), 1) + self.assertEqual(dic[d], 2) + self.assertEqual(dic[e], 2) + + d = self.theclass(2001, 1, 1) + # same thing + e = self.theclass(2001, 1, 1) + self.assertEqual(d, e) + self.assertEqual(hash(d), hash(e)) + + dic = {d: 1} + dic[e] = 2 + self.assertEqual(len(dic), 1) + self.assertEqual(dic[d], 2) + self.assertEqual(dic[e], 2) + + def test_computations(self): + a = self.theclass(2002, 1, 31) + b = self.theclass(1956, 1, 31) + + diff = a-b + self.assertEqual(diff.days, 46*365 + len(range(1956, 2002, 4))) + self.assertEqual(diff.seconds, 0) + self.assertEqual(diff.microseconds, 0) + + day = timedelta(1) + week = timedelta(7) + a = self.theclass(2002, 3, 2) + self.assertEqual(a + day, self.theclass(2002, 3, 3)) + self.assertEqual(day + a, self.theclass(2002, 3, 3)) + self.assertEqual(a - day, self.theclass(2002, 3, 1)) + self.assertEqual(-day + a, self.theclass(2002, 3, 1)) + self.assertEqual(a + week, self.theclass(2002, 3, 9)) + self.assertEqual(a - week, self.theclass(2002, 2, 23)) + self.assertEqual(a + 52*week, self.theclass(2003, 3, 1)) + self.assertEqual(a - 52*week, self.theclass(2001, 3, 3)) + self.assertEqual((a + week) - a, week) + self.assertEqual((a + day) - a, day) + self.assertEqual((a - week) - a, -week) + self.assertEqual((a - day) - a, -day) + self.assertEqual(a - (a + week), -week) + self.assertEqual(a - (a + day), -day) + self.assertEqual(a - (a - week), week) + self.assertEqual(a - (a - day), day) + + # Add/sub ints, longs, floats should be illegal + for i in 1, 1L, 1.0: + self.assertRaises(TypeError, lambda: a+i) + self.assertRaises(TypeError, lambda: a-i) + self.assertRaises(TypeError, lambda: i+a) + self.assertRaises(TypeError, lambda: i-a) + + # delta - date is senseless. + self.assertRaises(TypeError, lambda: day - a) + # mixing date and (delta or date) via * or // is senseless + self.assertRaises(TypeError, lambda: day * a) + self.assertRaises(TypeError, lambda: a * day) + self.assertRaises(TypeError, lambda: day // a) + self.assertRaises(TypeError, lambda: a // day) + self.assertRaises(TypeError, lambda: a * a) + self.assertRaises(TypeError, lambda: a // a) + # date + date is senseless + self.assertRaises(TypeError, lambda: a + a) + + def test_overflow(self): + tiny = self.theclass.resolution + + for delta in [tiny, timedelta(1), timedelta(2)]: + dt = self.theclass.min + delta + dt -= delta # no problem + self.assertRaises(OverflowError, dt.__sub__, delta) + self.assertRaises(OverflowError, dt.__add__, -delta) + + dt = self.theclass.max - delta + dt += delta # no problem + self.assertRaises(OverflowError, dt.__add__, delta) + self.assertRaises(OverflowError, dt.__sub__, -delta) + + def test_fromtimestamp(self): + import time + + # Try an arbitrary fixed value. + year, month, day = 1999, 9, 19 + ts = time.mktime((year, month, day, 0, 0, 0, 0, 0, -1)) + d = self.theclass.fromtimestamp(ts) + self.assertEqual(d.year, year) + self.assertEqual(d.month, month) + self.assertEqual(d.day, day) + + def test_insane_fromtimestamp(self): + # It's possible that some platform maps time_t to double, + # and that this test will fail there. This test should + # exempt such platforms (provided they return reasonable + # results!). + for insane in -1e200, 1e200: + self.assertRaises(ValueError, self.theclass.fromtimestamp, + insane) + + def test_today(self): + import time + + # We claim that today() is like fromtimestamp(time.time()), so + # prove it. 
+ for dummy in range(3): + today = self.theclass.today() + ts = time.time() + todayagain = self.theclass.fromtimestamp(ts) + if today == todayagain: + break + # There are several legit reasons that could fail: + # 1. It recently became midnight, between the today() and the + # time() calls. + # 2. The platform time() has such fine resolution that we'll + # never get the same value twice. + # 3. The platform time() has poor resolution, and we just + # happened to call today() right before a resolution quantum + # boundary. + # 4. The system clock got fiddled between calls. + # In any case, wait a little while and try again. + time.sleep(0.1) + + # It worked or it didn't. If it didn't, assume it's reason #2, and + # let the test pass if they're within half a second of each other. + if today != todayagain: + self.assertAlmostEqual(todayagain, today, + delta=timedelta(seconds=0.5)) + + def test_weekday(self): + for i in range(7): + # March 4, 2002 is a Monday + self.assertEqual(self.theclass(2002, 3, 4+i).weekday(), i) + self.assertEqual(self.theclass(2002, 3, 4+i).isoweekday(), i+1) + # January 2, 1956 is a Monday + self.assertEqual(self.theclass(1956, 1, 2+i).weekday(), i) + self.assertEqual(self.theclass(1956, 1, 2+i).isoweekday(), i+1) + + def test_isocalendar(self): + # Check examples from + # http://www.phys.uu.nl/~vgent/calendar/isocalendar.htm + for i in range(7): + d = self.theclass(2003, 12, 22+i) + self.assertEqual(d.isocalendar(), (2003, 52, i+1)) + d = self.theclass(2003, 12, 29) + timedelta(i) + self.assertEqual(d.isocalendar(), (2004, 1, i+1)) + d = self.theclass(2004, 1, 5+i) + self.assertEqual(d.isocalendar(), (2004, 2, i+1)) + d = self.theclass(2009, 12, 21+i) + self.assertEqual(d.isocalendar(), (2009, 52, i+1)) + d = self.theclass(2009, 12, 28) + timedelta(i) + self.assertEqual(d.isocalendar(), (2009, 53, i+1)) + d = self.theclass(2010, 1, 4+i) + self.assertEqual(d.isocalendar(), (2010, 1, i+1)) + + def test_iso_long_years(self): + # Calculate long ISO years and compare to table from + # http://www.phys.uu.nl/~vgent/calendar/isocalendar.htm + ISO_LONG_YEARS_TABLE = """ + 4 32 60 88 + 9 37 65 93 + 15 43 71 99 + 20 48 76 + 26 54 82 + + 105 133 161 189 + 111 139 167 195 + 116 144 172 + 122 150 178 + 128 156 184 + + 201 229 257 285 + 207 235 263 291 + 212 240 268 296 + 218 246 274 + 224 252 280 + + 303 331 359 387 + 308 336 364 392 + 314 342 370 398 + 320 348 376 + 325 353 381 + """ + iso_long_years = map(int, ISO_LONG_YEARS_TABLE.split()) + iso_long_years.sort() + L = [] + for i in range(400): + d = self.theclass(2000+i, 12, 31) + d1 = self.theclass(1600+i, 12, 31) + self.assertEqual(d.isocalendar()[1:], d1.isocalendar()[1:]) + if d.isocalendar()[1] == 53: + L.append(i) + self.assertEqual(L, iso_long_years) + + def test_isoformat(self): + t = self.theclass(2, 3, 2) + self.assertEqual(t.isoformat(), "0002-03-02") + + def test_ctime(self): + t = self.theclass(2002, 3, 2) + self.assertEqual(t.ctime(), "Sat Mar 2 00:00:00 2002") + + @unittest.expectedFailure + def test_strftime(self): + t = self.theclass(2005, 3, 2) + self.assertEqual(t.strftime("m:%m d:%d y:%y"), "m:03 d:02 y:05") + self.assertEqual(t.strftime(""), "") # SF bug #761337 + self.assertEqual(t.strftime('x'*1000), 'x'*1000) # SF bug #1556784 + + self.assertRaises(TypeError, t.strftime) # needs an arg + self.assertRaises(TypeError, t.strftime, "one", "two") # too many args + self.assertRaises(TypeError, t.strftime, 42) # arg wrong type + + # test that unicode input is allowed (issue 2782) + 
self.assertEqual(t.strftime(u"%m"), "03") + + # A naive object replaces %z and %Z w/ empty strings. + self.assertEqual(t.strftime("'%z' '%Z'"), "'' ''") + + #make sure that invalid format specifiers are handled correctly + #self.assertRaises(ValueError, t.strftime, "%e") + #self.assertRaises(ValueError, t.strftime, "%") + #self.assertRaises(ValueError, t.strftime, "%#") + + #oh well, some systems just ignore those invalid ones. + #at least, exercise them to make sure that no crashes + #are generated + for f in ["%e", "%", "%#"]: + try: + t.strftime(f) + except ValueError: + pass + + #check that this standard extension works + t.strftime("%f") + + + @unittest.expectedFailure + def test_format(self): + dt = self.theclass(2007, 9, 10) + self.assertEqual(dt.__format__(''), str(dt)) + + # check that a derived class's __str__() gets called + class A(self.theclass): + def __str__(self): + return 'A' + a = A(2007, 9, 10) + self.assertEqual(a.__format__(''), 'A') + + # check that a derived class's strftime gets called + class B(self.theclass): + def strftime(self, format_spec): + return 'B' + b = B(2007, 9, 10) + self.assertEqual(b.__format__(''), str(dt)) + + for fmt in ["m:%m d:%d y:%y", + "m:%m d:%d y:%y H:%H M:%M S:%S", + "%z %Z", + ]: + self.assertEqual(dt.__format__(fmt), dt.strftime(fmt)) + self.assertEqual(a.__format__(fmt), dt.strftime(fmt)) + self.assertEqual(b.__format__(fmt), 'B') + + def test_resolution_info(self): + self.assertIsInstance(self.theclass.min, self.theclass) + self.assertIsInstance(self.theclass.max, self.theclass) + self.assertIsInstance(self.theclass.resolution, timedelta) + self.assertTrue(self.theclass.max > self.theclass.min) + + def test_extreme_timedelta(self): + big = self.theclass.max - self.theclass.min + # 3652058 days, 23 hours, 59 minutes, 59 seconds, 999999 microseconds + n = (big.days*24*3600 + big.seconds)*1000000 + big.microseconds + # n == 315537897599999999 ~= 2**58.13 + justasbig = timedelta(0, 0, n) + self.assertEqual(big, justasbig) + self.assertEqual(self.theclass.min + big, self.theclass.max) + self.assertEqual(self.theclass.max - big, self.theclass.min) + + def test_timetuple(self): + for i in range(7): + # January 2, 1956 is a Monday (0) + d = self.theclass(1956, 1, 2+i) + t = d.timetuple() + self.assertEqual(t, (1956, 1, 2+i, 0, 0, 0, i, 2+i, -1)) + # February 1, 1956 is a Wednesday (2) + d = self.theclass(1956, 2, 1+i) + t = d.timetuple() + self.assertEqual(t, (1956, 2, 1+i, 0, 0, 0, (2+i)%7, 32+i, -1)) + # March 1, 1956 is a Thursday (3), and is the 31+29+1 = 61st day + # of the year. 
+ d = self.theclass(1956, 3, 1+i) + t = d.timetuple() + self.assertEqual(t, (1956, 3, 1+i, 0, 0, 0, (3+i)%7, 61+i, -1)) + self.assertEqual(t.tm_year, 1956) + self.assertEqual(t.tm_mon, 3) + self.assertEqual(t.tm_mday, 1+i) + self.assertEqual(t.tm_hour, 0) + self.assertEqual(t.tm_min, 0) + self.assertEqual(t.tm_sec, 0) + self.assertEqual(t.tm_wday, (3+i)%7) + self.assertEqual(t.tm_yday, 61+i) + self.assertEqual(t.tm_isdst, -1) + + # def test_pickling(self): + # args = 6, 7, 23 + # orig = self.theclass(*args) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + + def test_compare(self): + t1 = self.theclass(2, 3, 4) + t2 = self.theclass(2, 3, 4) + self.assertTrue(t1 == t2) + self.assertTrue(t1 <= t2) + self.assertTrue(t1 >= t2) + self.assertFalse(t1 != t2) + self.assertFalse(t1 < t2) + self.assertFalse(t1 > t2) + self.assertEqual(cmp(t1, t2), 0) + self.assertEqual(cmp(t2, t1), 0) + + for args in (3, 3, 3), (2, 4, 4), (2, 3, 5): + t2 = self.theclass(*args) # this is larger than t1 + self.assertTrue(t1 < t2) + self.assertTrue(t2 > t1) + self.assertTrue(t1 <= t2) + self.assertTrue(t2 >= t1) + self.assertTrue(t1 != t2) + self.assertTrue(t2 != t1) + self.assertFalse(t1 == t2) + self.assertFalse(t2 == t1) + self.assertFalse(t1 > t2) + self.assertFalse(t2 < t1) + self.assertFalse(t1 >= t2) + self.assertFalse(t2 <= t1) + self.assertEqual(cmp(t1, t2), -1) + self.assertEqual(cmp(t2, t1), 1) + + for badarg in OTHERSTUFF: + self.assertEqual(t1 == badarg, False) + self.assertEqual(t1 != badarg, True) + self.assertEqual(badarg == t1, False) + self.assertEqual(badarg != t1, True) + + self.assertRaises(TypeError, lambda: t1 < badarg) + self.assertRaises(TypeError, lambda: t1 > badarg) + self.assertRaises(TypeError, lambda: t1 >= badarg) + self.assertRaises(TypeError, lambda: badarg <= t1) + self.assertRaises(TypeError, lambda: badarg < t1) + self.assertRaises(TypeError, lambda: badarg > t1) + self.assertRaises(TypeError, lambda: badarg >= t1) + + def test_mixed_compare(self): + our = self.theclass(2000, 4, 5) + self.assertRaises(TypeError, cmp, our, 1) + self.assertRaises(TypeError, cmp, 1, our) + + class AnotherDateTimeClass(object): + def __cmp__(self, other): + # Return "equal" so calling this can't be confused with + # compare-by-address (which never says "equal" for distinct + # objects). + return 0 + __hash__ = None # Silence Py3k warning + + # This still errors, because date and datetime comparison raise + # TypeError instead of NotImplemented when they don't know what to + # do, in order to stop comparison from falling back to the default + # compare-by-address. + their = AnotherDateTimeClass() + self.assertRaises(TypeError, cmp, our, their) + # Oops: The next stab raises TypeError in the C implementation, + # but not in the Python implementation of datetime. The difference + # is due to that the Python implementation defines __cmp__ but + # the C implementation defines tp_richcompare. This is more pain + # to fix than it's worth, so commenting out the test. + # self.assertEqual(cmp(their, our), 0) + + # But date and datetime comparison return NotImplemented instead if the + # other object has a timetuple attr. This gives the other object a + # chance to do the comparison. 
+ class Comparable(AnotherDateTimeClass): + def timetuple(self): + return () + + their = Comparable() + self.assertEqual(cmp(our, their), 0) + self.assertEqual(cmp(their, our), 0) + self.assertTrue(our == their) + self.assertTrue(their == our) + + def test_bool(self): + # All dates are considered true. + self.assertTrue(self.theclass.min) + self.assertTrue(self.theclass.max) + + @unittest.expectedFailure + def test_strftime_out_of_range(self): + # For nasty technical reasons, we can't handle years before 1900. + cls = self.theclass + self.assertEqual(cls(1900, 1, 1).strftime("%Y"), "1900") + for y in 1, 49, 51, 99, 100, 1000, 1899: + self.assertRaises(ValueError, cls(y, 1, 1).strftime, "%Y") + + @unittest.expectedFailure + def test_replace(self): + cls = self.theclass + args = [1, 2, 3] + base = cls(*args) + self.assertEqual(base, base.replace()) + + i = 0 + for name, newval in (("year", 2), + ("month", 3), + ("day", 4)): + newargs = args[:] + newargs[i] = newval + expected = cls(*newargs) + got = base.replace(**{name: newval}) + self.assertEqual(expected, got) + i += 1 + + # Out of bounds. + base = cls(2000, 2, 29) + self.assertRaises(ValueError, base.replace, year=2001) + + @unittest.expectedFailure + def test_subclass_date(self): + + class C(self.theclass): + theAnswer = 42 + + def __new__(cls, *args, **kws): + temp = kws.copy() + extra = temp.pop('extra') + result = self.theclass.__new__(cls, *args, **temp) + result.extra = extra + return result + + def newmeth(self, start): + return start + self.year + self.month + + args = 2003, 4, 14 + + dt1 = self.theclass(*args) + dt2 = C(*args, **{'extra': 7}) + + self.assertEqual(dt2.__class__, C) + self.assertEqual(dt2.theAnswer, 42) + self.assertEqual(dt2.extra, 7) + self.assertEqual(dt1.toordinal(), dt2.toordinal()) + self.assertEqual(dt2.newmeth(-7), dt1.year + dt1.month - 7) + + # def test_pickling_subclass_date(self): + + # args = 6, 7, 23 + # orig = SubclassDate(*args) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + + # def test_backdoor_resistance(self): + # # For fast unpickling, the constructor accepts a pickle string. + # # This is a low-overhead backdoor. A user can (by intent or + # # mistake) pass a string directly, which (if it's the right length) + # # will get treated like a pickle, and bypass the normal sanity + # # checks in the constructor. This can create insane objects. + # # The constructor doesn't want to burn the time to validate all + # # fields, but does check the month field. This stops, e.g., + # # datetime.datetime('1995-03-25') from yielding an insane object. + # base = '1995-03-25' + # if not issubclass(self.theclass, datetime): + # base = base[:4] + # for month_byte in '9', chr(0), chr(13), '\xff': + # self.assertRaises(TypeError, self.theclass, + # base[:2] + month_byte + base[3:]) + # for ord_byte in range(1, 13): + # # This shouldn't blow up because of the month byte alone. If + # # the implementation changes to do more-careful checking, it may + # # blow up because other fields are insane. 
+ # self.theclass(base[:2] + chr(ord_byte) + base[3:]) + +############################################################################# +# datetime tests + +class SubclassDatetime(datetime): + sub_var = 1 + +class TestDateTime(TestDate): + + theclass = datetime + + def test_basic_attributes(self): + dt = self.theclass(2002, 3, 1, 12, 0) + self.assertEqual(dt.year, 2002) + self.assertEqual(dt.month, 3) + self.assertEqual(dt.day, 1) + self.assertEqual(dt.hour, 12) + self.assertEqual(dt.minute, 0) + self.assertEqual(dt.second, 0) + self.assertEqual(dt.microsecond, 0) + + def test_basic_attributes_nonzero(self): + # Make sure all attributes are non-zero so bugs in + # bit-shifting access show up. + dt = self.theclass(2002, 3, 1, 12, 59, 59, 8000) + self.assertEqual(dt.year, 2002) + self.assertEqual(dt.month, 3) + self.assertEqual(dt.day, 1) + self.assertEqual(dt.hour, 12) + self.assertEqual(dt.minute, 59) + self.assertEqual(dt.second, 59) + self.assertEqual(dt.microsecond, 8000) + + @unittest.expectedFailure + def test_roundtrip(self): + for dt in (self.theclass(1, 2, 3, 4, 5, 6, 7), + self.theclass.now()): + # Verify dt -> string -> datetime identity. + s = repr(dt) + self.assertTrue(s.startswith('datetime.')) + s = s[9:] + dt2 = eval(s) + self.assertEqual(dt, dt2) + + # Verify identity via reconstructing from pieces. + dt2 = self.theclass(dt.year, dt.month, dt.day, + dt.hour, dt.minute, dt.second, + dt.microsecond) + self.assertEqual(dt, dt2) + + @unittest.expectedFailure + def test_isoformat(self): + t = self.theclass(2, 3, 2, 4, 5, 1, 123) + self.assertEqual(t.isoformat(), "0002-03-02T04:05:01.000123") + self.assertEqual(t.isoformat('T'), "0002-03-02T04:05:01.000123") + self.assertEqual(t.isoformat(' '), "0002-03-02 04:05:01.000123") + self.assertEqual(t.isoformat('\x00'), "0002-03-02\x0004:05:01.000123") + # str is ISO format with the separator forced to a blank. + self.assertEqual(str(t), "0002-03-02 04:05:01.000123") + + t = self.theclass(2, 3, 2) + self.assertEqual(t.isoformat(), "0002-03-02T00:00:00") + self.assertEqual(t.isoformat('T'), "0002-03-02T00:00:00") + self.assertEqual(t.isoformat(' '), "0002-03-02 00:00:00") + # str is ISO format with the separator forced to a blank. + self.assertEqual(str(t), "0002-03-02 00:00:00") + + @unittest.expectedFailure + def test_format(self): + dt = self.theclass(2007, 9, 10, 4, 5, 1, 123) + self.assertEqual(dt.__format__(''), str(dt)) + + # check that a derived class's __str__() gets called + class A(self.theclass): + def __str__(self): + return 'A' + a = A(2007, 9, 10, 4, 5, 1, 123) + self.assertEqual(a.__format__(''), 'A') + + # check that a derived class's strftime gets called + class B(self.theclass): + def strftime(self, format_spec): + return 'B' + b = B(2007, 9, 10, 4, 5, 1, 123) + self.assertEqual(b.__format__(''), str(dt)) + + for fmt in ["m:%m d:%d y:%y", + "m:%m d:%d y:%y H:%H M:%M S:%S", + "%z %Z", + ]: + self.assertEqual(dt.__format__(fmt), dt.strftime(fmt)) + self.assertEqual(a.__format__(fmt), dt.strftime(fmt)) + self.assertEqual(b.__format__(fmt), 'B') + + @unittest.expectedFailure + def test_more_ctime(self): + # Test fields that TestDate doesn't touch. + import time + + t = self.theclass(2002, 3, 2, 18, 3, 5, 123) + self.assertEqual(t.ctime(), "Sat Mar 2 18:03:05 2002") + # Oops! The next line fails on Win2K under MSVC 6, so it's commented + # out. The difference is that t.ctime() produces " 2" for the day, + # but platform ctime() produces "02" for the day. According to + # C99, t.ctime() is correct here. 
+ # self.assertEqual(t.ctime(), time.ctime(time.mktime(t.timetuple()))) + + # So test a case where that difference doesn't matter. + t = self.theclass(2002, 3, 22, 18, 3, 5, 123) + self.assertEqual(t.ctime(), time.ctime(time.mktime(t.timetuple()))) + + def test_tz_independent_comparing(self): + dt1 = self.theclass(2002, 3, 1, 9, 0, 0) + dt2 = self.theclass(2002, 3, 1, 10, 0, 0) + dt3 = self.theclass(2002, 3, 1, 9, 0, 0) + self.assertEqual(dt1, dt3) + self.assertTrue(dt2 > dt3) + + # Make sure comparison doesn't forget microseconds, and isn't done + # via comparing a float timestamp (an IEEE double doesn't have enough + # precision to span microsecond resolution across years 1 thru 9999, + # so comparing via timestamp necessarily calls some distinct values + # equal). + dt1 = self.theclass(MAXYEAR, 12, 31, 23, 59, 59, 999998) + us = timedelta(microseconds=1) + dt2 = dt1 + us + self.assertEqual(dt2 - dt1, us) + self.assertTrue(dt1 < dt2) + + @unittest.expectedFailure + def test_strftime_with_bad_tzname_replace(self): + # verify ok if tzinfo.tzname().replace() returns a non-string + class MyTzInfo(FixedOffset): + def tzname(self, dt): + class MyStr(str): + def replace(self, *args): + return None + return MyStr('name') + t = self.theclass(2005, 3, 2, 0, 0, 0, 0, MyTzInfo(3, 'name')) + self.assertRaises(TypeError, t.strftime, '%Z') + + def test_bad_constructor_arguments(self): + # bad years + self.theclass(MINYEAR, 1, 1) # no exception + self.theclass(MAXYEAR, 1, 1) # no exception + self.assertRaises(ValueError, self.theclass, MINYEAR-1, 1, 1) + self.assertRaises(ValueError, self.theclass, MAXYEAR+1, 1, 1) + # bad months + self.theclass(2000, 1, 1) # no exception + self.theclass(2000, 12, 1) # no exception + self.assertRaises(ValueError, self.theclass, 2000, 0, 1) + self.assertRaises(ValueError, self.theclass, 2000, 13, 1) + # bad days + self.theclass(2000, 2, 29) # no exception + self.theclass(2004, 2, 29) # no exception + self.theclass(2400, 2, 29) # no exception + self.assertRaises(ValueError, self.theclass, 2000, 2, 30) + self.assertRaises(ValueError, self.theclass, 2001, 2, 29) + self.assertRaises(ValueError, self.theclass, 2100, 2, 29) + self.assertRaises(ValueError, self.theclass, 1900, 2, 29) + self.assertRaises(ValueError, self.theclass, 2000, 1, 0) + self.assertRaises(ValueError, self.theclass, 2000, 1, 32) + # bad hours + self.theclass(2000, 1, 31, 0) # no exception + self.theclass(2000, 1, 31, 23) # no exception + self.assertRaises(ValueError, self.theclass, 2000, 1, 31, -1) + self.assertRaises(ValueError, self.theclass, 2000, 1, 31, 24) + # bad minutes + self.theclass(2000, 1, 31, 23, 0) # no exception + self.theclass(2000, 1, 31, 23, 59) # no exception + self.assertRaises(ValueError, self.theclass, 2000, 1, 31, 23, -1) + self.assertRaises(ValueError, self.theclass, 2000, 1, 31, 23, 60) + # bad seconds + self.theclass(2000, 1, 31, 23, 59, 0) # no exception + self.theclass(2000, 1, 31, 23, 59, 59) # no exception + self.assertRaises(ValueError, self.theclass, 2000, 1, 31, 23, 59, -1) + self.assertRaises(ValueError, self.theclass, 2000, 1, 31, 23, 59, 60) + # bad microseconds + self.theclass(2000, 1, 31, 23, 59, 59, 0) # no exception + self.theclass(2000, 1, 31, 23, 59, 59, 999999) # no exception + self.assertRaises(ValueError, self.theclass, + 2000, 1, 31, 23, 59, 59, -1) + self.assertRaises(ValueError, self.theclass, + 2000, 1, 31, 23, 59, 59, + 1000000) + + def test_hash_equality(self): + d = self.theclass(2000, 12, 31, 23, 30, 17) + e = self.theclass(2000, 12, 31, 23, 30, 17) + 
self.assertEqual(d, e) + self.assertEqual(hash(d), hash(e)) + + dic = {d: 1} + dic[e] = 2 + self.assertEqual(len(dic), 1) + self.assertEqual(dic[d], 2) + self.assertEqual(dic[e], 2) + + d = self.theclass(2001, 1, 1, 0, 5, 17) + e = self.theclass(2001, 1, 1, 0, 5, 17) + self.assertEqual(d, e) + self.assertEqual(hash(d), hash(e)) + + dic = {d: 1} + dic[e] = 2 + self.assertEqual(len(dic), 1) + self.assertEqual(dic[d], 2) + self.assertEqual(dic[e], 2) + + def test_computations(self): + a = self.theclass(2002, 1, 31) + b = self.theclass(1956, 1, 31) + diff = a-b + self.assertEqual(diff.days, 46*365 + len(range(1956, 2002, 4))) + self.assertEqual(diff.seconds, 0) + self.assertEqual(diff.microseconds, 0) + a = self.theclass(2002, 3, 2, 17, 6) + millisec = timedelta(0, 0, 1000) + hour = timedelta(0, 3600) + day = timedelta(1) + week = timedelta(7) + self.assertEqual(a + hour, self.theclass(2002, 3, 2, 18, 6)) + self.assertEqual(hour + a, self.theclass(2002, 3, 2, 18, 6)) + self.assertEqual(a + 10*hour, self.theclass(2002, 3, 3, 3, 6)) + self.assertEqual(a - hour, self.theclass(2002, 3, 2, 16, 6)) + self.assertEqual(-hour + a, self.theclass(2002, 3, 2, 16, 6)) + self.assertEqual(a - hour, a + -hour) + self.assertEqual(a - 20*hour, self.theclass(2002, 3, 1, 21, 6)) + self.assertEqual(a + day, self.theclass(2002, 3, 3, 17, 6)) + self.assertEqual(a - day, self.theclass(2002, 3, 1, 17, 6)) + self.assertEqual(a + week, self.theclass(2002, 3, 9, 17, 6)) + self.assertEqual(a - week, self.theclass(2002, 2, 23, 17, 6)) + self.assertEqual(a + 52*week, self.theclass(2003, 3, 1, 17, 6)) + self.assertEqual(a - 52*week, self.theclass(2001, 3, 3, 17, 6)) + self.assertEqual((a + week) - a, week) + self.assertEqual((a + day) - a, day) + self.assertEqual((a + hour) - a, hour) + self.assertEqual((a + millisec) - a, millisec) + self.assertEqual((a - week) - a, -week) + self.assertEqual((a - day) - a, -day) + self.assertEqual((a - hour) - a, -hour) + self.assertEqual((a - millisec) - a, -millisec) + self.assertEqual(a - (a + week), -week) + self.assertEqual(a - (a + day), -day) + self.assertEqual(a - (a + hour), -hour) + self.assertEqual(a - (a + millisec), -millisec) + self.assertEqual(a - (a - week), week) + self.assertEqual(a - (a - day), day) + self.assertEqual(a - (a - hour), hour) + self.assertEqual(a - (a - millisec), millisec) + self.assertEqual(a + (week + day + hour + millisec), + self.theclass(2002, 3, 10, 18, 6, 0, 1000)) + self.assertEqual(a + (week + day + hour + millisec), + (((a + week) + day) + hour) + millisec) + self.assertEqual(a - (week + day + hour + millisec), + self.theclass(2002, 2, 22, 16, 5, 59, 999000)) + self.assertEqual(a - (week + day + hour + millisec), + (((a - week) - day) - hour) - millisec) + # Add/sub ints, longs, floats should be illegal + for i in 1, 1L, 1.0: + self.assertRaises(TypeError, lambda: a+i) + self.assertRaises(TypeError, lambda: a-i) + self.assertRaises(TypeError, lambda: i+a) + self.assertRaises(TypeError, lambda: i-a) + + # delta - datetime is senseless. 
+ self.assertRaises(TypeError, lambda: day - a) + # mixing datetime and (delta or datetime) via * or // is senseless + self.assertRaises(TypeError, lambda: day * a) + self.assertRaises(TypeError, lambda: a * day) + self.assertRaises(TypeError, lambda: day // a) + self.assertRaises(TypeError, lambda: a // day) + self.assertRaises(TypeError, lambda: a * a) + self.assertRaises(TypeError, lambda: a // a) + # datetime + datetime is senseless + self.assertRaises(TypeError, lambda: a + a) + + # def test_pickling(self): + # args = 6, 7, 23, 20, 59, 1, 64**2 + # orig = self.theclass(*args) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + + # def test_more_pickling(self): + # a = self.theclass(2003, 2, 7, 16, 48, 37, 444116) + # for proto in range(pickle.HIGHEST_PROTOCOL + 1): + # s = pickle.dumps(a, proto) + # b = pickle.loads(s) + # self.assertEqual(b.year, 2003) + # self.assertEqual(b.month, 2) + # self.assertEqual(b.day, 7) + + # def test_pickling_subclass_datetime(self): + # args = 6, 7, 23, 20, 59, 1, 64**2 + # orig = SubclassDatetime(*args) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + + def test_more_compare(self): + # The test_compare() inherited from TestDate covers the error cases. + # We just want to test lexicographic ordering on the members datetime + # has that date lacks. + args = [2000, 11, 29, 20, 58, 16, 999998] + t1 = self.theclass(*args) + t2 = self.theclass(*args) + self.assertTrue(t1 == t2) + self.assertTrue(t1 <= t2) + self.assertTrue(t1 >= t2) + self.assertFalse(t1 != t2) + self.assertFalse(t1 < t2) + self.assertFalse(t1 > t2) + self.assertEqual(cmp(t1, t2), 0) + self.assertEqual(cmp(t2, t1), 0) + + for i in range(len(args)): + newargs = args[:] + newargs[i] = args[i] + 1 + t2 = self.theclass(*newargs) # this is larger than t1 + self.assertTrue(t1 < t2) + self.assertTrue(t2 > t1) + self.assertTrue(t1 <= t2) + self.assertTrue(t2 >= t1) + self.assertTrue(t1 != t2) + self.assertTrue(t2 != t1) + self.assertFalse(t1 == t2) + self.assertFalse(t2 == t1) + self.assertFalse(t1 > t2) + self.assertFalse(t2 < t1) + self.assertFalse(t1 >= t2) + self.assertFalse(t2 <= t1) + self.assertEqual(cmp(t1, t2), -1) + self.assertEqual(cmp(t2, t1), 1) + + + # A helper for timestamp constructor tests. + def verify_field_equality(self, expected, got): + self.assertEqual(expected.tm_year, got.year) + self.assertEqual(expected.tm_mon, got.month) + self.assertEqual(expected.tm_mday, got.day) + self.assertEqual(expected.tm_hour, got.hour) + self.assertEqual(expected.tm_min, got.minute) + self.assertEqual(expected.tm_sec, got.second) + + def test_fromtimestamp(self): + import time + + ts = time.time() + expected = time.localtime(ts) + got = self.theclass.fromtimestamp(ts) + self.verify_field_equality(expected, got) + + def test_utcfromtimestamp(self): + import time + + ts = time.time() + expected = time.gmtime(ts) + got = self.theclass.utcfromtimestamp(ts) + self.verify_field_equality(expected, got) + + def test_microsecond_rounding(self): + # Test whether fromtimestamp "rounds up" floats that are less + # than one microsecond smaller than an integer. 
+ self.assertEqual(self.theclass.fromtimestamp(0.9999999), + self.theclass.fromtimestamp(1)) + + @unittest.expectedFailure + def test_insane_fromtimestamp(self): + # It's possible that some platform maps time_t to double, + # and that this test will fail there. This test should + # exempt such platforms (provided they return reasonable + # results!). + for insane in -1e200, 1e200: + self.assertRaises(ValueError, self.theclass.fromtimestamp, + insane) + + @unittest.expectedFailure + def test_insane_utcfromtimestamp(self): + # It's possible that some platform maps time_t to double, + # and that this test will fail there. This test should + # exempt such platforms (provided they return reasonable + # results!). + for insane in -1e200, 1e200: + self.assertRaises(ValueError, self.theclass.utcfromtimestamp, + insane) + + # @unittest.skipIf(sys.platform == "win32", "Windows doesn't accept negative timestamps") + # def test_negative_float_fromtimestamp(self): + # # The result is tz-dependent; at least test that this doesn't + # # fail (like it did before bug 1646728 was fixed). + # self.theclass.fromtimestamp(-1.05) + + # @unittest.skipIf(sys.platform == "win32", "Windows doesn't accept negative timestamps") + # def test_negative_float_utcfromtimestamp(self): + # d = self.theclass.utcfromtimestamp(-1.05) + # self.assertEqual(d, self.theclass(1969, 12, 31, 23, 59, 58, 950000)) + + def test_utcnow(self): + import time + + # Call it a success if utcnow() and utcfromtimestamp() are within + # a second of each other. + tolerance = timedelta(seconds=1) + for dummy in range(3): + from_now = self.theclass.utcnow() + from_timestamp = self.theclass.utcfromtimestamp(time.time()) + if abs(from_timestamp - from_now) <= tolerance: + break + # Else try again a few times. + self.assertLessEqual(abs(from_timestamp - from_now), tolerance) + + # def test_strptime(self): + # import _strptime + + # string = '2004-12-01 13:02:47.197' + # format = '%Y-%m-%d %H:%M:%S.%f' + # result, frac = _strptime._strptime(string, format) + # expected = self.theclass(*(result[0:6]+(frac,))) + # got = self.theclass.strptime(string, format) + # self.assertEqual(expected, got) + + def test_more_timetuple(self): + # This tests fields beyond those tested by the TestDate.test_timetuple. + t = self.theclass(2004, 12, 31, 6, 22, 33) + self.assertEqual(t.timetuple(), (2004, 12, 31, 6, 22, 33, 4, 366, -1)) + self.assertEqual(t.timetuple(), + (t.year, t.month, t.day, + t.hour, t.minute, t.second, + t.weekday(), + t.toordinal() - date(t.year, 1, 1).toordinal() + 1, + -1)) + tt = t.timetuple() + self.assertEqual(tt.tm_year, t.year) + self.assertEqual(tt.tm_mon, t.month) + self.assertEqual(tt.tm_mday, t.day) + self.assertEqual(tt.tm_hour, t.hour) + self.assertEqual(tt.tm_min, t.minute) + self.assertEqual(tt.tm_sec, t.second) + self.assertEqual(tt.tm_wday, t.weekday()) + self.assertEqual(tt.tm_yday, t.toordinal() - + date(t.year, 1, 1).toordinal() + 1) + self.assertEqual(tt.tm_isdst, -1) + + @unittest.expectedFailure + def test_more_strftime(self): + # This tests fields beyond those tested by the TestDate.test_strftime. 
+ t = self.theclass(2004, 12, 31, 6, 22, 33, 47) + self.assertEqual(t.strftime("%m %d %y %f %S %M %H %j"), + "12 31 04 000047 33 22 06 366") + + def test_extract(self): + dt = self.theclass(2002, 3, 4, 18, 45, 3, 1234) + self.assertEqual(dt.date(), date(2002, 3, 4)) + self.assertEqual(dt.time(), time(18, 45, 3, 1234)) + + def test_combine(self): + d = date(2002, 3, 4) + t = time(18, 45, 3, 1234) + expected = self.theclass(2002, 3, 4, 18, 45, 3, 1234) + combine = self.theclass.combine + dt = combine(d, t) + self.assertEqual(dt, expected) + + dt = combine(time=t, date=d) + self.assertEqual(dt, expected) + + self.assertEqual(d, dt.date()) + self.assertEqual(t, dt.time()) + self.assertEqual(dt, combine(dt.date(), dt.time())) + + self.assertRaises(TypeError, combine) # need an arg + self.assertRaises(TypeError, combine, d) # need two args + self.assertRaises(TypeError, combine, t, d) # args reversed + self.assertRaises(TypeError, combine, d, t, 1) # too many args + self.assertRaises(TypeError, combine, "date", "time") # wrong types + + @unittest.expectedFailure + def test_replace(self): + cls = self.theclass + args = [1, 2, 3, 4, 5, 6, 7] + base = cls(*args) + self.assertEqual(base, base.replace()) + + i = 0 + for name, newval in (("year", 2), + ("month", 3), + ("day", 4), + ("hour", 5), + ("minute", 6), + ("second", 7), + ("microsecond", 8)): + newargs = args[:] + newargs[i] = newval + expected = cls(*newargs) + got = base.replace(**{name: newval}) + self.assertEqual(expected, got) + i += 1 + + # Out of bounds. + base = cls(2000, 2, 29) + self.assertRaises(ValueError, base.replace, year=2001) + + def test_astimezone(self): + # Pretty boring! The TZ test is more interesting here. astimezone() + # simply can't be applied to a naive object. + dt = self.theclass.now() + f = FixedOffset(44, "") + self.assertRaises(TypeError, dt.astimezone) # not enough args + self.assertRaises(TypeError, dt.astimezone, f, f) # too many args + self.assertRaises(TypeError, dt.astimezone, dt) # arg wrong type + self.assertRaises(ValueError, dt.astimezone, f) # naive + self.assertRaises(ValueError, dt.astimezone, tz=f) # naive + + class Bogus(tzinfo): + def utcoffset(self, dt): return None + def dst(self, dt): return timedelta(0) + bog = Bogus() + self.assertRaises(ValueError, dt.astimezone, bog) # naive + + class AlsoBogus(tzinfo): + def utcoffset(self, dt): return timedelta(0) + def dst(self, dt): return None + alsobog = AlsoBogus() + self.assertRaises(ValueError, dt.astimezone, alsobog) # also naive + + @unittest.expectedFailure + def test_subclass_datetime(self): + + class C(self.theclass): + theAnswer = 42 + + def __new__(cls, *args, **kws): + temp = kws.copy() + extra = temp.pop('extra') + result = self.theclass.__new__(cls, *args, **temp) + result.extra = extra + return result + + def newmeth(self, start): + return start + self.year + self.month + self.second + + args = 2003, 4, 14, 12, 13, 41 + + dt1 = self.theclass(*args) + dt2 = C(*args, **{'extra': 7}) + + self.assertEqual(dt2.__class__, C) + self.assertEqual(dt2.theAnswer, 42) + self.assertEqual(dt2.extra, 7) + self.assertEqual(dt1.toordinal(), dt2.toordinal()) + self.assertEqual(dt2.newmeth(-7), dt1.year + dt1.month + + dt1.second - 7) + +class SubclassTime(time): + sub_var = 1 + +class TestTime(HarmlessMixedComparison, unittest.TestCase): + + theclass = time + + def test_basic_attributes(self): + t = self.theclass(12, 0) + self.assertEqual(t.hour, 12) + self.assertEqual(t.minute, 0) + self.assertEqual(t.second, 0) + self.assertEqual(t.microsecond, 0) + + 
def test_basic_attributes_nonzero(self): + # Make sure all attributes are non-zero so bugs in + # bit-shifting access show up. + t = self.theclass(12, 59, 59, 8000) + self.assertEqual(t.hour, 12) + self.assertEqual(t.minute, 59) + self.assertEqual(t.second, 59) + self.assertEqual(t.microsecond, 8000) + + @unittest.expectedFailure + def test_roundtrip(self): + t = self.theclass(1, 2, 3, 4) + + # Verify t -> string -> time identity. + s = repr(t) + self.assertTrue(s.startswith('datetime.')) + s = s[9:] + t2 = eval(s) + self.assertEqual(t, t2) + + # Verify identity via reconstructing from pieces. + t2 = self.theclass(t.hour, t.minute, t.second, + t.microsecond) + self.assertEqual(t, t2) + + def test_comparing(self): + args = [1, 2, 3, 4] + t1 = self.theclass(*args) + t2 = self.theclass(*args) + self.assertTrue(t1 == t2) + self.assertTrue(t1 <= t2) + self.assertTrue(t1 >= t2) + self.assertFalse(t1 != t2) + self.assertFalse(t1 < t2) + self.assertFalse(t1 > t2) + self.assertEqual(cmp(t1, t2), 0) + self.assertEqual(cmp(t2, t1), 0) + + for i in range(len(args)): + newargs = args[:] + newargs[i] = args[i] + 1 + t2 = self.theclass(*newargs) # this is larger than t1 + self.assertTrue(t1 < t2) + self.assertTrue(t2 > t1) + self.assertTrue(t1 <= t2) + self.assertTrue(t2 >= t1) + self.assertTrue(t1 != t2) + self.assertTrue(t2 != t1) + self.assertFalse(t1 == t2) + self.assertFalse(t2 == t1) + self.assertFalse(t1 > t2) + self.assertFalse(t2 < t1) + self.assertFalse(t1 >= t2) + self.assertFalse(t2 <= t1) + self.assertEqual(cmp(t1, t2), -1) + self.assertEqual(cmp(t2, t1), 1) + + for badarg in OTHERSTUFF: + self.assertEqual(t1 == badarg, False) + self.assertEqual(t1 != badarg, True) + self.assertEqual(badarg == t1, False) + self.assertEqual(badarg != t1, True) + + self.assertRaises(TypeError, lambda: t1 <= badarg) + self.assertRaises(TypeError, lambda: t1 < badarg) + self.assertRaises(TypeError, lambda: t1 > badarg) + self.assertRaises(TypeError, lambda: t1 >= badarg) + self.assertRaises(TypeError, lambda: badarg <= t1) + self.assertRaises(TypeError, lambda: badarg < t1) + self.assertRaises(TypeError, lambda: badarg > t1) + self.assertRaises(TypeError, lambda: badarg >= t1) + + def test_bad_constructor_arguments(self): + # bad hours + self.theclass(0, 0) # no exception + self.theclass(23, 0) # no exception + self.assertRaises(ValueError, self.theclass, -1, 0) + self.assertRaises(ValueError, self.theclass, 24, 0) + # bad minutes + self.theclass(23, 0) # no exception + self.theclass(23, 59) # no exception + self.assertRaises(ValueError, self.theclass, 23, -1) + self.assertRaises(ValueError, self.theclass, 23, 60) + # bad seconds + self.theclass(23, 59, 0) # no exception + self.theclass(23, 59, 59) # no exception + self.assertRaises(ValueError, self.theclass, 23, 59, -1) + self.assertRaises(ValueError, self.theclass, 23, 59, 60) + # bad microseconds + self.theclass(23, 59, 59, 0) # no exception + self.theclass(23, 59, 59, 999999) # no exception + self.assertRaises(ValueError, self.theclass, 23, 59, 59, -1) + self.assertRaises(ValueError, self.theclass, 23, 59, 59, 1000000) + + def test_hash_equality(self): + d = self.theclass(23, 30, 17) + e = self.theclass(23, 30, 17) + self.assertEqual(d, e) + self.assertEqual(hash(d), hash(e)) + + dic = {d: 1} + dic[e] = 2 + self.assertEqual(len(dic), 1) + self.assertEqual(dic[d], 2) + self.assertEqual(dic[e], 2) + + d = self.theclass(0, 5, 17) + e = self.theclass(0, 5, 17) + self.assertEqual(d, e) + self.assertEqual(hash(d), hash(e)) + + dic = {d: 1} + dic[e] = 2 + 
self.assertEqual(len(dic), 1) + self.assertEqual(dic[d], 2) + self.assertEqual(dic[e], 2) + + def test_isoformat(self): + t = self.theclass(4, 5, 1, 123) + self.assertEqual(t.isoformat(), "04:05:01.000123") + self.assertEqual(t.isoformat(), str(t)) + + t = self.theclass() + self.assertEqual(t.isoformat(), "00:00:00") + self.assertEqual(t.isoformat(), str(t)) + + t = self.theclass(microsecond=1) + self.assertEqual(t.isoformat(), "00:00:00.000001") + self.assertEqual(t.isoformat(), str(t)) + + t = self.theclass(microsecond=10) + self.assertEqual(t.isoformat(), "00:00:00.000010") + self.assertEqual(t.isoformat(), str(t)) + + t = self.theclass(microsecond=100) + self.assertEqual(t.isoformat(), "00:00:00.000100") + self.assertEqual(t.isoformat(), str(t)) + + t = self.theclass(microsecond=1000) + self.assertEqual(t.isoformat(), "00:00:00.001000") + self.assertEqual(t.isoformat(), str(t)) + + t = self.theclass(microsecond=10000) + self.assertEqual(t.isoformat(), "00:00:00.010000") + self.assertEqual(t.isoformat(), str(t)) + + t = self.theclass(microsecond=100000) + self.assertEqual(t.isoformat(), "00:00:00.100000") + self.assertEqual(t.isoformat(), str(t)) + + def test_1653736(self): + # verify it doesn't accept extra keyword arguments + t = self.theclass(second=1) + self.assertRaises(TypeError, t.isoformat, foo=3) + + @unittest.expectedFailure + def test_strftime(self): + t = self.theclass(1, 2, 3, 4) + self.assertEqual(t.strftime('%H %M %S %f'), "01 02 03 000004") + # A naive object replaces %z and %Z with empty strings. + self.assertEqual(t.strftime("'%z' '%Z'"), "'' ''") + + @unittest.expectedFailure + def test_format(self): + t = self.theclass(1, 2, 3, 4) + self.assertEqual(t.__format__(''), str(t)) + + # check that a derived class's __str__() gets called + class A(self.theclass): + def __str__(self): + return 'A' + a = A(1, 2, 3, 4) + self.assertEqual(a.__format__(''), 'A') + + # check that a derived class's strftime gets called + class B(self.theclass): + def strftime(self, format_spec): + return 'B' + b = B(1, 2, 3, 4) + self.assertEqual(b.__format__(''), str(t)) + + for fmt in ['%H %M %S', + ]: + self.assertEqual(t.__format__(fmt), t.strftime(fmt)) + self.assertEqual(a.__format__(fmt), t.strftime(fmt)) + self.assertEqual(b.__format__(fmt), 'B') + + def test_str(self): + self.assertEqual(str(self.theclass(1, 2, 3, 4)), "01:02:03.000004") + self.assertEqual(str(self.theclass(10, 2, 3, 4000)), "10:02:03.004000") + self.assertEqual(str(self.theclass(0, 2, 3, 400000)), "00:02:03.400000") + self.assertEqual(str(self.theclass(12, 2, 3, 0)), "12:02:03") + self.assertEqual(str(self.theclass(23, 15, 0, 0)), "23:15:00") + + def test_repr(self): + name = 'datetime.' 
+ self.theclass.__name__ + self.assertEqual(repr(self.theclass(1, 2, 3, 4)), + "%s(1, 2, 3, 4)" % name) + self.assertEqual(repr(self.theclass(10, 2, 3, 4000)), + "%s(10, 2, 3, 4000)" % name) + self.assertEqual(repr(self.theclass(0, 2, 3, 400000)), + "%s(0, 2, 3, 400000)" % name) + self.assertEqual(repr(self.theclass(12, 2, 3, 0)), + "%s(12, 2, 3)" % name) + self.assertEqual(repr(self.theclass(23, 15, 0, 0)), + "%s(23, 15)" % name) + + def test_resolution_info(self): + self.assertIsInstance(self.theclass.min, self.theclass) + self.assertIsInstance(self.theclass.max, self.theclass) + self.assertIsInstance(self.theclass.resolution, timedelta) + self.assertTrue(self.theclass.max > self.theclass.min) + + # def test_pickling(self): + # args = 20, 59, 16, 64**2 + # orig = self.theclass(*args) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + + # def test_pickling_subclass_time(self): + # args = 20, 59, 16, 64**2 + # orig = SubclassTime(*args) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + + def test_bool(self): + cls = self.theclass + self.assertTrue(cls(1)) + self.assertTrue(cls(0, 1)) + self.assertTrue(cls(0, 0, 1)) + self.assertTrue(cls(0, 0, 0, 1)) + self.assertFalse(cls(0)) + self.assertFalse(cls()) + + @unittest.expectedFailure + def test_replace(self): + cls = self.theclass + args = [1, 2, 3, 4] + base = cls(*args) + self.assertEqual(base, base.replace()) + + i = 0 + for name, newval in (("hour", 5), + ("minute", 6), + ("second", 7), + ("microsecond", 8)): + newargs = args[:] + newargs[i] = newval + expected = cls(*newargs) + got = base.replace(**{name: newval}) + self.assertEqual(expected, got) + i += 1 + + # Out of bounds. + base = cls(1) + self.assertRaises(ValueError, base.replace, hour=24) + self.assertRaises(ValueError, base.replace, minute=-1) + self.assertRaises(ValueError, base.replace, second=100) + self.assertRaises(ValueError, base.replace, microsecond=1000000) + + @unittest.expectedFailure + def test_subclass_time(self): + + class C(self.theclass): + theAnswer = 42 + + def __new__(cls, *args, **kws): + temp = kws.copy() + extra = temp.pop('extra') + result = self.theclass.__new__(cls, *args, **temp) + result.extra = extra + return result + + def newmeth(self, start): + return start + self.hour + self.second + + args = 4, 5, 6 + + dt1 = self.theclass(*args) + dt2 = C(*args, **{'extra': 7}) + + self.assertEqual(dt2.__class__, C) + self.assertEqual(dt2.theAnswer, 42) + self.assertEqual(dt2.extra, 7) + self.assertEqual(dt1.isoformat(), dt2.isoformat()) + self.assertEqual(dt2.newmeth(-7), dt1.hour + dt1.second - 7) + + def test_backdoor_resistance(self): + # see TestDate.test_backdoor_resistance(). + base = '2:59.0' + for hour_byte in ' ', '9', chr(24), '\xff': + self.assertRaises(TypeError, self.theclass, + hour_byte + base[1:]) + +# A mixin for classes with a tzinfo= argument. Subclasses must define +# theclass as a class attribute, and theclass(1, 1, 1, tzinfo=whatever) +# must be legit (which is true for time and datetime). +class TZInfoBase(object): + + def test_argument_passing(self): + cls = self.theclass + # A datetime passes itself on, a time passes None. 
+ class introspective(tzinfo): + def tzname(self, dt): return dt and "real" or "none" + def utcoffset(self, dt): + return timedelta(minutes = dt and 42 or -42) + dst = utcoffset + + obj = cls(1, 2, 3, tzinfo=introspective()) + + expected = cls is time and "none" or "real" + self.assertEqual(obj.tzname(), expected) + + expected = timedelta(minutes=(cls is time and -42 or 42)) + self.assertEqual(obj.utcoffset(), expected) + self.assertEqual(obj.dst(), expected) + + def test_bad_tzinfo_classes(self): + cls = self.theclass + self.assertRaises(TypeError, cls, 1, 1, 1, tzinfo=12) + + class NiceTry(object): + def __init__(self): pass + def utcoffset(self, dt): pass + self.assertRaises(TypeError, cls, 1, 1, 1, tzinfo=NiceTry) + + class BetterTry(tzinfo): + def __init__(self): pass + def utcoffset(self, dt): pass + b = BetterTry() + t = cls(1, 1, 1, tzinfo=b) + self.assertIs(t.tzinfo, b) + + @unittest.skip('grumpy') + def test_utc_offset_out_of_bounds(self): + class Edgy(tzinfo): + def __init__(self, offset): + self.offset = timedelta(minutes=offset) + def utcoffset(self, dt): + return self.offset + + cls = self.theclass + for offset, legit in ((-1440, False), + (-1439, True), + (1439, True), + (1440, False)): + if cls is time: + t = cls(1, 2, 3, tzinfo=Edgy(offset)) + elif cls is datetime: + t = cls(6, 6, 6, 1, 2, 3, tzinfo=Edgy(offset)) + else: + assert 0, "impossible" + if legit: + aofs = abs(offset) + h, m = divmod(aofs, 60) + tag = "%c%02d:%02d" % (offset < 0 and '-' or '+', h, m) + if isinstance(t, datetime): + t = t.timetz() + self.assertEqual(str(t), "01:02:03" + tag) + else: + self.assertRaises(ValueError, str, t) + + def test_tzinfo_classes(self): + cls = self.theclass + class C1(tzinfo): + def utcoffset(self, dt): return None + def dst(self, dt): return None + def tzname(self, dt): return None + for t in (cls(1, 1, 1), + cls(1, 1, 1, tzinfo=None), + cls(1, 1, 1, tzinfo=C1())): + self.assertIsNone(t.utcoffset()) + self.assertIsNone(t.dst()) + self.assertIsNone(t.tzname()) + + class C3(tzinfo): + def utcoffset(self, dt): return timedelta(minutes=-1439) + def dst(self, dt): return timedelta(minutes=1439) + def tzname(self, dt): return "aname" + t = cls(1, 1, 1, tzinfo=C3()) + self.assertEqual(t.utcoffset(), timedelta(minutes=-1439)) + self.assertEqual(t.dst(), timedelta(minutes=1439)) + self.assertEqual(t.tzname(), "aname") + + # Wrong types. + class C4(tzinfo): + def utcoffset(self, dt): return "aname" + def dst(self, dt): return 7 + def tzname(self, dt): return 0 + t = cls(1, 1, 1, tzinfo=C4()) + self.assertRaises(TypeError, t.utcoffset) + self.assertRaises(TypeError, t.dst) + self.assertRaises(TypeError, t.tzname) + + # Offset out of range. + class C6(tzinfo): + def utcoffset(self, dt): return timedelta(hours=-24) + def dst(self, dt): return timedelta(hours=24) + t = cls(1, 1, 1, tzinfo=C6()) + self.assertRaises(ValueError, t.utcoffset) + self.assertRaises(ValueError, t.dst) + + # Not a whole number of minutes. + class C7(tzinfo): + def utcoffset(self, dt): return timedelta(seconds=61) + def dst(self, dt): return timedelta(microseconds=-81) + t = cls(1, 1, 1, tzinfo=C7()) + self.assertRaises(ValueError, t.utcoffset) + self.assertRaises(ValueError, t.dst) + + @unittest.skip('grumpy') + def test_aware_compare(self): + cls = self.theclass + + # Ensure that utcoffset() gets ignored if the comparands have + # the same tzinfo member. 
+ class OperandDependentOffset(tzinfo): + def utcoffset(self, t): + if t.minute < 10: + # d0 and d1 equal after adjustment + return timedelta(minutes=t.minute) + else: + # d2 off in the weeds + return timedelta(minutes=59) + + base = cls(8, 9, 10, tzinfo=OperandDependentOffset()) + d0 = base.replace(minute=3) + d1 = base.replace(minute=9) + d2 = base.replace(minute=11) + for x in d0, d1, d2: + for y in d0, d1, d2: + got = cmp(x, y) + expected = cmp(x.minute, y.minute) + self.assertEqual(got, expected) + + # However, if they're different members, uctoffset is not ignored. + # Note that a time can't actually have an operand-depedent offset, + # though (and time.utcoffset() passes None to tzinfo.utcoffset()), + # so skip this test for time. + if cls is not time: + d0 = base.replace(minute=3, tzinfo=OperandDependentOffset()) + d1 = base.replace(minute=9, tzinfo=OperandDependentOffset()) + d2 = base.replace(minute=11, tzinfo=OperandDependentOffset()) + for x in d0, d1, d2: + for y in d0, d1, d2: + got = cmp(x, y) + if (x is d0 or x is d1) and (y is d0 or y is d1): + expected = 0 + elif x is y is d2: + expected = 0 + elif x is d2: + expected = -1 + else: + assert y is d2 + expected = 1 + self.assertEqual(got, expected) + + +# Testing time objects with a non-None tzinfo. +class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase): + theclass = time + + def test_empty(self): + t = self.theclass() + self.assertEqual(t.hour, 0) + self.assertEqual(t.minute, 0) + self.assertEqual(t.second, 0) + self.assertEqual(t.microsecond, 0) + self.assertIsNone(t.tzinfo) + + @unittest.expectedFailure + def test_zones(self): + est = FixedOffset(-300, "EST", 1) + utc = FixedOffset(0, "UTC", -2) + met = FixedOffset(60, "MET", 3) + t1 = time( 7, 47, tzinfo=est) + t2 = time(12, 47, tzinfo=utc) + t3 = time(13, 47, tzinfo=met) + t4 = time(microsecond=40) + t5 = time(microsecond=40, tzinfo=utc) + + self.assertEqual(t1.tzinfo, est) + self.assertEqual(t2.tzinfo, utc) + self.assertEqual(t3.tzinfo, met) + self.assertIsNone(t4.tzinfo) + self.assertEqual(t5.tzinfo, utc) + + self.assertEqual(t1.utcoffset(), timedelta(minutes=-300)) + self.assertEqual(t2.utcoffset(), timedelta(minutes=0)) + self.assertEqual(t3.utcoffset(), timedelta(minutes=60)) + self.assertIsNone(t4.utcoffset()) + self.assertRaises(TypeError, t1.utcoffset, "no args") + + self.assertEqual(t1.tzname(), "EST") + self.assertEqual(t2.tzname(), "UTC") + self.assertEqual(t3.tzname(), "MET") + self.assertIsNone(t4.tzname()) + self.assertRaises(TypeError, t1.tzname, "no args") + + self.assertEqual(t1.dst(), timedelta(minutes=1)) + self.assertEqual(t2.dst(), timedelta(minutes=-2)) + self.assertEqual(t3.dst(), timedelta(minutes=3)) + self.assertIsNone(t4.dst()) + self.assertRaises(TypeError, t1.dst, "no args") + + self.assertEqual(hash(t1), hash(t2)) + self.assertEqual(hash(t1), hash(t3)) + self.assertEqual(hash(t2), hash(t3)) + + self.assertEqual(t1, t2) + self.assertEqual(t1, t3) + self.assertEqual(t2, t3) + self.assertRaises(TypeError, lambda: t4 == t5) # mixed tz-aware & naive + self.assertRaises(TypeError, lambda: t4 < t5) # mixed tz-aware & naive + self.assertRaises(TypeError, lambda: t5 < t4) # mixed tz-aware & naive + + self.assertEqual(str(t1), "07:47:00-05:00") + self.assertEqual(str(t2), "12:47:00+00:00") + self.assertEqual(str(t3), "13:47:00+01:00") + self.assertEqual(str(t4), "00:00:00.000040") + self.assertEqual(str(t5), "00:00:00.000040+00:00") + + self.assertEqual(t1.isoformat(), "07:47:00-05:00") + self.assertEqual(t2.isoformat(), "12:47:00+00:00") + 
self.assertEqual(t3.isoformat(), "13:47:00+01:00") + self.assertEqual(t4.isoformat(), "00:00:00.000040") + self.assertEqual(t5.isoformat(), "00:00:00.000040+00:00") + + d = 'datetime.time' + self.assertEqual(repr(t1), d + "(7, 47, tzinfo=est)") + self.assertEqual(repr(t2), d + "(12, 47, tzinfo=utc)") + self.assertEqual(repr(t3), d + "(13, 47, tzinfo=met)") + self.assertEqual(repr(t4), d + "(0, 0, 0, 40)") + self.assertEqual(repr(t5), d + "(0, 0, 0, 40, tzinfo=utc)") + + self.assertEqual(t1.strftime("%H:%M:%S %%Z=%Z %%z=%z"), + "07:47:00 %Z=EST %z=-0500") + self.assertEqual(t2.strftime("%H:%M:%S %Z %z"), "12:47:00 UTC +0000") + self.assertEqual(t3.strftime("%H:%M:%S %Z %z"), "13:47:00 MET +0100") + + yuck = FixedOffset(-1439, "%z %Z %%z%%Z") + t1 = time(23, 59, tzinfo=yuck) + self.assertEqual(t1.strftime("%H:%M %%Z='%Z' %%z='%z'"), + "23:59 %Z='%z %Z %%z%%Z' %z='-2359'") + + # Check that an invalid tzname result raises an exception. + class Badtzname(tzinfo): + def tzname(self, dt): return 42 + t = time(2, 3, 4, tzinfo=Badtzname()) + self.assertEqual(t.strftime("%H:%M:%S"), "02:03:04") + self.assertRaises(TypeError, t.strftime, "%Z") + + @unittest.expectedFailure + def test_hash_edge_cases(self): + # Offsets that overflow a basic time. + t1 = self.theclass(0, 1, 2, 3, tzinfo=FixedOffset(1439, "")) + t2 = self.theclass(0, 0, 2, 3, tzinfo=FixedOffset(1438, "")) + self.assertEqual(hash(t1), hash(t2)) + + t1 = self.theclass(23, 58, 6, 100, tzinfo=FixedOffset(-1000, "")) + t2 = self.theclass(23, 48, 6, 100, tzinfo=FixedOffset(-1010, "")) + self.assertEqual(hash(t1), hash(t2)) + + # def test_pickling(self): + # # Try one without a tzinfo. + # args = 20, 59, 16, 64**2 + # orig = self.theclass(*args) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + + # # Try one with a tzinfo. + # tinfo = PicklableFixedOffset(-300, 'cookie') + # orig = self.theclass(5, 6, 7, tzinfo=tinfo) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + # self.assertIsInstance(derived.tzinfo, PicklableFixedOffset) + # self.assertEqual(derived.utcoffset(), timedelta(minutes=-300)) + # self.assertEqual(derived.tzname(), 'cookie') + + def test_more_bool(self): + # Test cases with non-None tzinfo. + cls = self.theclass + + t = cls(0, tzinfo=FixedOffset(-300, "")) + self.assertTrue(t) + + t = cls(5, tzinfo=FixedOffset(-300, "")) + self.assertTrue(t) + + t = cls(5, tzinfo=FixedOffset(300, "")) + self.assertFalse(t) + + t = cls(23, 59, tzinfo=FixedOffset(23*60 + 59, "")) + self.assertFalse(t) + + # Mostly ensuring this doesn't overflow internally. + t = cls(0, tzinfo=FixedOffset(23*60 + 59, "")) + self.assertTrue(t) + + # But this should yield a value error -- the utcoffset is bogus. + t = cls(0, tzinfo=FixedOffset(24*60, "")) + self.assertRaises(ValueError, lambda: bool(t)) + + # Likewise. 
+ t = cls(0, tzinfo=FixedOffset(-24*60, "")) + self.assertRaises(ValueError, lambda: bool(t)) + + @unittest.expectedFailure + def test_replace(self): + cls = self.theclass + z100 = FixedOffset(100, "+100") + zm200 = FixedOffset(timedelta(minutes=-200), "-200") + args = [1, 2, 3, 4, z100] + base = cls(*args) + self.assertEqual(base, base.replace()) + + i = 0 + for name, newval in (("hour", 5), + ("minute", 6), + ("second", 7), + ("microsecond", 8), + ("tzinfo", zm200)): + newargs = args[:] + newargs[i] = newval + expected = cls(*newargs) + got = base.replace(**{name: newval}) + self.assertEqual(expected, got) + i += 1 + + # Ensure we can get rid of a tzinfo. + self.assertEqual(base.tzname(), "+100") + base2 = base.replace(tzinfo=None) + self.assertIsNone(base2.tzinfo) + self.assertIsNone(base2.tzname()) + + # Ensure we can add one. + base3 = base2.replace(tzinfo=z100) + self.assertEqual(base, base3) + self.assertIs(base.tzinfo, base3.tzinfo) + + # Out of bounds. + base = cls(1) + self.assertRaises(ValueError, base.replace, hour=24) + self.assertRaises(ValueError, base.replace, minute=-1) + self.assertRaises(ValueError, base.replace, second=100) + self.assertRaises(ValueError, base.replace, microsecond=1000000) + + @unittest.expectedFailure + def test_mixed_compare(self): + t1 = time(1, 2, 3) + t2 = time(1, 2, 3) + self.assertEqual(t1, t2) + t2 = t2.replace(tzinfo=None) + self.assertEqual(t1, t2) + t2 = t2.replace(tzinfo=FixedOffset(None, "")) + self.assertEqual(t1, t2) + t2 = t2.replace(tzinfo=FixedOffset(0, "")) + self.assertRaises(TypeError, lambda: t1 == t2) + + # In time w/ identical tzinfo objects, utcoffset is ignored. + class Varies(tzinfo): + def __init__(self): + self.offset = timedelta(minutes=22) + def utcoffset(self, t): + self.offset += timedelta(minutes=1) + return self.offset + + v = Varies() + t1 = t2.replace(tzinfo=v) + t2 = t2.replace(tzinfo=v) + self.assertEqual(t1.utcoffset(), timedelta(minutes=23)) + self.assertEqual(t2.utcoffset(), timedelta(minutes=24)) + self.assertEqual(t1, t2) + + # But if they're not identical, it isn't ignored. + t2 = t2.replace(tzinfo=Varies()) + self.assertTrue(t1 < t2) # t1's offset counter still going up + + @unittest.expectedFailure + def test_subclass_timetz(self): + + class C(self.theclass): + theAnswer = 42 + + def __new__(cls, *args, **kws): + temp = kws.copy() + extra = temp.pop('extra') + result = self.theclass.__new__(cls, *args, **temp) + result.extra = extra + return result + + def newmeth(self, start): + return start + self.hour + self.second + + args = 4, 5, 6, 500, FixedOffset(-300, "EST", 1) + + dt1 = self.theclass(*args) + dt2 = C(*args, **{'extra': 7}) + + self.assertEqual(dt2.__class__, C) + self.assertEqual(dt2.theAnswer, 42) + self.assertEqual(dt2.extra, 7) + self.assertEqual(dt1.utcoffset(), dt2.utcoffset()) + self.assertEqual(dt2.newmeth(-7), dt1.hour + dt1.second - 7) + + +# Testing datetime objects with a non-None tzinfo. + +class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase): + theclass = datetime + + def test_trivial(self): + dt = self.theclass(1, 2, 3, 4, 5, 6, 7) + self.assertEqual(dt.year, 1) + self.assertEqual(dt.month, 2) + self.assertEqual(dt.day, 3) + self.assertEqual(dt.hour, 4) + self.assertEqual(dt.minute, 5) + self.assertEqual(dt.second, 6) + self.assertEqual(dt.microsecond, 7) + self.assertEqual(dt.tzinfo, None) + + def test_even_more_compare(self): + # The test_compare() and test_more_compare() inherited from TestDate + # and TestDateTime covered non-tzinfo cases. 
+ + # Smallest possible after UTC adjustment. + t1 = self.theclass(1, 1, 1, tzinfo=FixedOffset(1439, "")) + # Largest possible after UTC adjustment. + t2 = self.theclass(MAXYEAR, 12, 31, 23, 59, 59, 999999, + tzinfo=FixedOffset(-1439, "")) + + # Make sure those compare correctly, and w/o overflow. + self.assertTrue(t1 < t2) + self.assertTrue(t1 != t2) + self.assertTrue(t2 > t1) + + self.assertTrue(t1 == t1) + self.assertTrue(t2 == t2) + + # Equal afer adjustment. + t1 = self.theclass(1, 12, 31, 23, 59, tzinfo=FixedOffset(1, "")) + t2 = self.theclass(2, 1, 1, 3, 13, tzinfo=FixedOffset(3*60+13+2, "")) + self.assertEqual(t1, t2) + + # Change t1 not to subtract a minute, and t1 should be larger. + t1 = self.theclass(1, 12, 31, 23, 59, tzinfo=FixedOffset(0, "")) + self.assertTrue(t1 > t2) + + # Change t1 to subtract 2 minutes, and t1 should be smaller. + t1 = self.theclass(1, 12, 31, 23, 59, tzinfo=FixedOffset(2, "")) + self.assertTrue(t1 < t2) + + # Back to the original t1, but make seconds resolve it. + t1 = self.theclass(1, 12, 31, 23, 59, tzinfo=FixedOffset(1, ""), + second=1) + self.assertTrue(t1 > t2) + + # Likewise, but make microseconds resolve it. + t1 = self.theclass(1, 12, 31, 23, 59, tzinfo=FixedOffset(1, ""), + microsecond=1) + self.assertTrue(t1 > t2) + + # Make t2 naive and it should fail. + t2 = self.theclass.min + self.assertRaises(TypeError, lambda: t1 == t2) + self.assertEqual(t2, t2) + + # It's also naive if it has tzinfo but tzinfo.utcoffset() is None. + class Naive(tzinfo): + def utcoffset(self, dt): return None + t2 = self.theclass(5, 6, 7, tzinfo=Naive()) + self.assertRaises(TypeError, lambda: t1 == t2) + self.assertEqual(t2, t2) + + # OTOH, it's OK to compare two of these mixing the two ways of being + # naive. + t1 = self.theclass(5, 6, 7) + self.assertEqual(t1, t2) + + # Try a bogus uctoffset. + class Bogus(tzinfo): + def utcoffset(self, dt): + return timedelta(minutes=1440) # out of bounds + t1 = self.theclass(2, 2, 2, tzinfo=Bogus()) + t2 = self.theclass(2, 2, 2, tzinfo=FixedOffset(0, "")) + self.assertRaises(ValueError, lambda: t1 == t2) + + # def test_pickling(self): + # # Try one without a tzinfo. + # args = 6, 7, 23, 20, 59, 1, 64**2 + # orig = self.theclass(*args) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + + # # Try one with a tzinfo. + # tinfo = PicklableFixedOffset(-300, 'cookie') + # orig = self.theclass(*args, **{'tzinfo': tinfo}) + # derived = self.theclass(1, 1, 1, tzinfo=FixedOffset(0, "", 0)) + # for pickler, unpickler, proto in pickle_choices: + # green = pickler.dumps(orig, proto) + # derived = unpickler.loads(green) + # self.assertEqual(orig, derived) + # self.assertIsInstance(derived.tzinfo, PicklableFixedOffset) + # self.assertEqual(derived.utcoffset(), timedelta(minutes=-300)) + # self.assertEqual(derived.tzname(), 'cookie') + + def test_extreme_hashes(self): + # If an attempt is made to hash these via subtracting the offset + # then hashing a datetime object, OverflowError results. The + # Python implementation used to blow up here. + t = self.theclass(1, 1, 1, tzinfo=FixedOffset(1439, "")) + hash(t) + t = self.theclass(MAXYEAR, 12, 31, 23, 59, 59, 999999, + tzinfo=FixedOffset(-1439, "")) + hash(t) + + # OTOH, an OOB offset should blow up. 
+ t = self.theclass(5, 5, 5, tzinfo=FixedOffset(-1440, "")) + self.assertRaises(ValueError, hash, t) + + @unittest.expectedFailure + def test_zones(self): + est = FixedOffset(-300, "EST") + utc = FixedOffset(0, "UTC") + met = FixedOffset(60, "MET") + t1 = datetime(2002, 3, 19, 7, 47, tzinfo=est) + t2 = datetime(2002, 3, 19, 12, 47, tzinfo=utc) + t3 = datetime(2002, 3, 19, 13, 47, tzinfo=met) + self.assertEqual(t1.tzinfo, est) + self.assertEqual(t2.tzinfo, utc) + self.assertEqual(t3.tzinfo, met) + self.assertEqual(t1.utcoffset(), timedelta(minutes=-300)) + self.assertEqual(t2.utcoffset(), timedelta(minutes=0)) + self.assertEqual(t3.utcoffset(), timedelta(minutes=60)) + self.assertEqual(t1.tzname(), "EST") + self.assertEqual(t2.tzname(), "UTC") + self.assertEqual(t3.tzname(), "MET") + self.assertEqual(hash(t1), hash(t2)) + self.assertEqual(hash(t1), hash(t3)) + self.assertEqual(hash(t2), hash(t3)) + self.assertEqual(t1, t2) + self.assertEqual(t1, t3) + self.assertEqual(t2, t3) + self.assertEqual(str(t1), "2002-03-19 07:47:00-05:00") + self.assertEqual(str(t2), "2002-03-19 12:47:00+00:00") + self.assertEqual(str(t3), "2002-03-19 13:47:00+01:00") + d = 'datetime.datetime(2002, 3, 19, ' + self.assertEqual(repr(t1), d + "7, 47, tzinfo=est)") + self.assertEqual(repr(t2), d + "12, 47, tzinfo=utc)") + self.assertEqual(repr(t3), d + "13, 47, tzinfo=met)") + + def test_combine(self): + met = FixedOffset(60, "MET") + d = date(2002, 3, 4) + tz = time(18, 45, 3, 1234, tzinfo=met) + dt = datetime.combine(d, tz) + self.assertEqual(dt, datetime(2002, 3, 4, 18, 45, 3, 1234, + tzinfo=met)) + + def test_extract(self): + met = FixedOffset(60, "MET") + dt = self.theclass(2002, 3, 4, 18, 45, 3, 1234, tzinfo=met) + self.assertEqual(dt.date(), date(2002, 3, 4)) + self.assertEqual(dt.time(), time(18, 45, 3, 1234)) + self.assertEqual(dt.timetz(), time(18, 45, 3, 1234, tzinfo=met)) + + @unittest.expectedFailure + def test_tz_aware_arithmetic(self): + import random + + now = self.theclass.now() + tz55 = FixedOffset(-330, "west 5:30") + timeaware = now.time().replace(tzinfo=tz55) + nowaware = self.theclass.combine(now.date(), timeaware) + self.assertIs(nowaware.tzinfo, tz55) + self.assertEqual(nowaware.timetz(), timeaware) + + # Can't mix aware and non-aware. + self.assertRaises(TypeError, lambda: now - nowaware) + self.assertRaises(TypeError, lambda: nowaware - now) + + # And adding datetime's doesn't make sense, aware or not. + self.assertRaises(TypeError, lambda: now + nowaware) + self.assertRaises(TypeError, lambda: nowaware + now) + self.assertRaises(TypeError, lambda: nowaware + nowaware) + + # Subtracting should yield 0. + self.assertEqual(now - now, timedelta(0)) + self.assertEqual(nowaware - nowaware, timedelta(0)) + + # Adding a delta should preserve tzinfo. + delta = timedelta(weeks=1, minutes=12, microseconds=5678) + nowawareplus = nowaware + delta + self.assertIs(nowaware.tzinfo, tz55) + nowawareplus2 = delta + nowaware + self.assertIs(nowawareplus2.tzinfo, tz55) + self.assertEqual(nowawareplus, nowawareplus2) + + # that - delta should be what we started with, and that - what we + # started with should be delta. + diff = nowawareplus - delta + self.assertIs(diff.tzinfo, tz55) + self.assertEqual(nowaware, diff) + self.assertRaises(TypeError, lambda: delta - nowawareplus) + self.assertEqual(nowawareplus - nowaware, delta) + + # Make up a random timezone. + tzr = FixedOffset(random.randrange(-1439, 1440), "randomtimezone") + # Attach it to nowawareplus. 
+ nowawareplus = nowawareplus.replace(tzinfo=tzr) + self.assertIs(nowawareplus.tzinfo, tzr) + # Make sure the difference takes the timezone adjustments into account. + got = nowaware - nowawareplus + # Expected: (nowaware base - nowaware offset) - + # (nowawareplus base - nowawareplus offset) = + # (nowaware base - nowawareplus base) + + # (nowawareplus offset - nowaware offset) = + # -delta + nowawareplus offset - nowaware offset + expected = nowawareplus.utcoffset() - nowaware.utcoffset() - delta + self.assertEqual(got, expected) + + # Try max possible difference. + min = self.theclass(1, 1, 1, tzinfo=FixedOffset(1439, "min")) + max = self.theclass(MAXYEAR, 12, 31, 23, 59, 59, 999999, + tzinfo=FixedOffset(-1439, "max")) + maxdiff = max - min + self.assertEqual(maxdiff, self.theclass.max - self.theclass.min + + timedelta(minutes=2*1439)) + + @unittest.expectedFailure + def test_tzinfo_now(self): + meth = self.theclass.now + # Ensure it doesn't require tzinfo (i.e., that this doesn't blow up). + base = meth() + # Try with and without naming the keyword. + off42 = FixedOffset(42, "42") + another = meth(off42) + again = meth(tz=off42) + self.assertIs(another.tzinfo, again.tzinfo) + self.assertEqual(another.utcoffset(), timedelta(minutes=42)) + # Bad argument with and w/o naming the keyword. + self.assertRaises(TypeError, meth, 16) + self.assertRaises(TypeError, meth, tzinfo=16) + # Bad keyword name. + self.assertRaises(TypeError, meth, tinfo=off42) + # Too many args. + self.assertRaises(TypeError, meth, off42, off42) + + # We don't know which time zone we're in, and don't have a tzinfo + # class to represent it, so seeing whether a tz argument actually + # does a conversion is tricky. + weirdtz = FixedOffset(timedelta(hours=15, minutes=58), "weirdtz", 0) + utc = FixedOffset(0, "utc", 0) + for dummy in range(3): + now = datetime.now(weirdtz) + self.assertIs(now.tzinfo, weirdtz) + utcnow = datetime.utcnow().replace(tzinfo=utc) + now2 = utcnow.astimezone(weirdtz) + if abs(now - now2) < timedelta(seconds=30): + break + # Else the code is broken, or more than 30 seconds passed between + # calls; assuming the latter, just try again. + else: + # Three strikes and we're out. + self.fail("utcnow(), now(tz), or astimezone() may be broken") + + @unittest.expectedFailure + def test_tzinfo_fromtimestamp(self): + import time + meth = self.theclass.fromtimestamp + ts = time.time() + # Ensure it doesn't require tzinfo (i.e., that this doesn't blow up). + base = meth(ts) + # Try with and without naming the keyword. + off42 = FixedOffset(42, "42") + another = meth(ts, off42) + again = meth(ts, tz=off42) + self.assertIs(another.tzinfo, again.tzinfo) + self.assertEqual(another.utcoffset(), timedelta(minutes=42)) + # Bad argument with and w/o naming the keyword. + self.assertRaises(TypeError, meth, ts, 16) + self.assertRaises(TypeError, meth, ts, tzinfo=16) + # Bad keyword name. + self.assertRaises(TypeError, meth, ts, tinfo=off42) + # Too many args. + self.assertRaises(TypeError, meth, ts, off42, off42) + # Too few args. + self.assertRaises(TypeError, meth) + + # Try to make sure tz= actually does some conversion. + timestamp = 1000000000 + utcdatetime = datetime.utcfromtimestamp(timestamp) + # In POSIX (epoch 1970), that's 2001-09-09 01:46:40 UTC, give or take. + # But on some flavor of Mac, it's nowhere near that. So we can't have + # any idea here what time that actually is, we can only test that + # relative changes match. 
+ utcoffset = timedelta(hours=-15, minutes=39) # arbitrary, but not zero + tz = FixedOffset(utcoffset, "tz", 0) + expected = utcdatetime + utcoffset + got = datetime.fromtimestamp(timestamp, tz) + self.assertEqual(expected, got.replace(tzinfo=None)) + + def test_tzinfo_utcnow(self): + meth = self.theclass.utcnow + # Ensure it doesn't require tzinfo (i.e., that this doesn't blow up). + base = meth() + # Try with and without naming the keyword; for whatever reason, + # utcnow() doesn't accept a tzinfo argument. + off42 = FixedOffset(42, "42") + self.assertRaises(TypeError, meth, off42) + self.assertRaises(TypeError, meth, tzinfo=off42) + + def test_tzinfo_utcfromtimestamp(self): + import time + meth = self.theclass.utcfromtimestamp + ts = time.time() + # Ensure it doesn't require tzinfo (i.e., that this doesn't blow up). + base = meth(ts) + # Try with and without naming the keyword; for whatever reason, + # utcfromtimestamp() doesn't accept a tzinfo argument. + off42 = FixedOffset(42, "42") + self.assertRaises(TypeError, meth, ts, off42) + self.assertRaises(TypeError, meth, ts, tzinfo=off42) + + def test_tzinfo_timetuple(self): + # TestDateTime tested most of this. datetime adds a twist to the + # DST flag. + class DST(tzinfo): + def __init__(self, dstvalue): + if isinstance(dstvalue, int): + dstvalue = timedelta(minutes=dstvalue) + self.dstvalue = dstvalue + def dst(self, dt): + return self.dstvalue + + cls = self.theclass + for dstvalue, flag in (-33, 1), (33, 1), (0, 0), (None, -1): + d = cls(1, 1, 1, 10, 20, 30, 40, tzinfo=DST(dstvalue)) + t = d.timetuple() + self.assertEqual(1, t.tm_year) + self.assertEqual(1, t.tm_mon) + self.assertEqual(1, t.tm_mday) + self.assertEqual(10, t.tm_hour) + self.assertEqual(20, t.tm_min) + self.assertEqual(30, t.tm_sec) + self.assertEqual(0, t.tm_wday) + self.assertEqual(1, t.tm_yday) + self.assertEqual(flag, t.tm_isdst) + + # dst() returns wrong type. + self.assertRaises(TypeError, cls(1, 1, 1, tzinfo=DST("x")).timetuple) + + # dst() at the edge. + self.assertEqual(cls(1,1,1, tzinfo=DST(1439)).timetuple().tm_isdst, 1) + self.assertEqual(cls(1,1,1, tzinfo=DST(-1439)).timetuple().tm_isdst, 1) + + # dst() out of range. + self.assertRaises(ValueError, cls(1,1,1, tzinfo=DST(1440)).timetuple) + self.assertRaises(ValueError, cls(1,1,1, tzinfo=DST(-1440)).timetuple) + + def test_utctimetuple(self): + class DST(tzinfo): + def __init__(self, dstvalue): + if isinstance(dstvalue, int): + dstvalue = timedelta(minutes=dstvalue) + self.dstvalue = dstvalue + def dst(self, dt): + return self.dstvalue + + cls = self.theclass + # This can't work: DST didn't implement utcoffset. + self.assertRaises(NotImplementedError, + cls(1, 1, 1, tzinfo=DST(0)).utcoffset) + + class UOFS(DST): + def __init__(self, uofs, dofs=None): + DST.__init__(self, dofs) + self.uofs = timedelta(minutes=uofs) + def utcoffset(self, dt): + return self.uofs + + # Ensure tm_isdst is 0 regardless of what dst() says: DST is never + # in effect for a UTC time. 
+ for dstvalue in -33, 33, 0, None: + d = cls(1, 2, 3, 10, 20, 30, 40, tzinfo=UOFS(-53, dstvalue)) + t = d.utctimetuple() + self.assertEqual(d.year, t.tm_year) + self.assertEqual(d.month, t.tm_mon) + self.assertEqual(d.day, t.tm_mday) + self.assertEqual(11, t.tm_hour) # 20mm + 53mm = 1hn + 13mm + self.assertEqual(13, t.tm_min) + self.assertEqual(d.second, t.tm_sec) + self.assertEqual(d.weekday(), t.tm_wday) + self.assertEqual(d.toordinal() - date(1, 1, 1).toordinal() + 1, + t.tm_yday) + self.assertEqual(0, t.tm_isdst) + + # At the edges, UTC adjustment can normalize into years out-of-range + # for a datetime object. Ensure that a correct timetuple is + # created anyway. + tiny = cls(MINYEAR, 1, 1, 0, 0, 37, tzinfo=UOFS(1439)) + # That goes back 1 minute less than a full day. + t = tiny.utctimetuple() + self.assertEqual(t.tm_year, MINYEAR-1) + self.assertEqual(t.tm_mon, 12) + self.assertEqual(t.tm_mday, 31) + self.assertEqual(t.tm_hour, 0) + self.assertEqual(t.tm_min, 1) + self.assertEqual(t.tm_sec, 37) + self.assertEqual(t.tm_yday, 366) # "year 0" is a leap year + self.assertEqual(t.tm_isdst, 0) + + huge = cls(MAXYEAR, 12, 31, 23, 59, 37, 999999, tzinfo=UOFS(-1439)) + # That goes forward 1 minute less than a full day. + t = huge.utctimetuple() + self.assertEqual(t.tm_year, MAXYEAR+1) + self.assertEqual(t.tm_mon, 1) + self.assertEqual(t.tm_mday, 1) + self.assertEqual(t.tm_hour, 23) + self.assertEqual(t.tm_min, 58) + self.assertEqual(t.tm_sec, 37) + self.assertEqual(t.tm_yday, 1) + self.assertEqual(t.tm_isdst, 0) + + @unittest.expectedFailure + def test_tzinfo_isoformat(self): + zero = FixedOffset(0, "+00:00") + plus = FixedOffset(220, "+03:40") + minus = FixedOffset(-231, "-03:51") + unknown = FixedOffset(None, "") + + cls = self.theclass + datestr = '0001-02-03' + for ofs in None, zero, plus, minus, unknown: + for us in 0, 987001: + d = cls(1, 2, 3, 4, 5, 59, us, tzinfo=ofs) + timestr = '04:05:59' + (us and '.987001' or '') + ofsstr = ofs is not None and d.tzname() or '' + tailstr = timestr + ofsstr + iso = d.isoformat() + self.assertEqual(iso, datestr + 'T' + tailstr) + self.assertEqual(iso, d.isoformat('T')) + self.assertEqual(d.isoformat('k'), datestr + 'k' + tailstr) + self.assertEqual(str(d), datestr + ' ' + tailstr) + + @unittest.expectedFailure + def test_replace(self): + cls = self.theclass + z100 = FixedOffset(100, "+100") + zm200 = FixedOffset(timedelta(minutes=-200), "-200") + args = [1, 2, 3, 4, 5, 6, 7, z100] + base = cls(*args) + self.assertEqual(base, base.replace()) + + i = 0 + for name, newval in (("year", 2), + ("month", 3), + ("day", 4), + ("hour", 5), + ("minute", 6), + ("second", 7), + ("microsecond", 8), + ("tzinfo", zm200)): + newargs = args[:] + newargs[i] = newval + expected = cls(*newargs) + got = base.replace(**{name: newval}) + self.assertEqual(expected, got) + i += 1 + + # Ensure we can get rid of a tzinfo. + self.assertEqual(base.tzname(), "+100") + base2 = base.replace(tzinfo=None) + self.assertIsNone(base2.tzinfo) + self.assertIsNone(base2.tzname()) + + # Ensure we can add one. + base3 = base2.replace(tzinfo=z100) + self.assertEqual(base, base3) + self.assertIs(base.tzinfo, base3.tzinfo) + + # Out of bounds. + base = cls(2000, 2, 29) + self.assertRaises(ValueError, base.replace, year=2001) + + @unittest.expectedFailure + def test_more_astimezone(self): + # The inherited test_astimezone covered some trivial and error cases. 
+ fnone = FixedOffset(None, "None") + f44m = FixedOffset(44, "44") + fm5h = FixedOffset(-timedelta(hours=5), "m300") + + dt = self.theclass.now(tz=f44m) + self.assertIs(dt.tzinfo, f44m) + # Replacing with degenerate tzinfo raises an exception. + self.assertRaises(ValueError, dt.astimezone, fnone) + # Ditto with None tz. + self.assertRaises(TypeError, dt.astimezone, None) + # Replacing with same tzinfo makes no change. + x = dt.astimezone(dt.tzinfo) + self.assertIs(x.tzinfo, f44m) + self.assertEqual(x.date(), dt.date()) + self.assertEqual(x.time(), dt.time()) + + # Replacing with different tzinfo does adjust. + got = dt.astimezone(fm5h) + self.assertIs(got.tzinfo, fm5h) + self.assertEqual(got.utcoffset(), timedelta(hours=-5)) + expected = dt - dt.utcoffset() # in effect, convert to UTC + expected += fm5h.utcoffset(dt) # and from there to local time + expected = expected.replace(tzinfo=fm5h) # and attach new tzinfo + self.assertEqual(got.date(), expected.date()) + self.assertEqual(got.time(), expected.time()) + self.assertEqual(got.timetz(), expected.timetz()) + self.assertIs(got.tzinfo, expected.tzinfo) + self.assertEqual(got, expected) + + @unittest.expectedFailure + def test_aware_subtract(self): + cls = self.theclass + + # Ensure that utcoffset() is ignored when the operands have the + # same tzinfo member. + class OperandDependentOffset(tzinfo): + def utcoffset(self, t): + if t.minute < 10: + # d0 and d1 equal after adjustment + return timedelta(minutes=t.minute) + else: + # d2 off in the weeds + return timedelta(minutes=59) + + base = cls(8, 9, 10, 11, 12, 13, 14, tzinfo=OperandDependentOffset()) + d0 = base.replace(minute=3) + d1 = base.replace(minute=9) + d2 = base.replace(minute=11) + for x in d0, d1, d2: + for y in d0, d1, d2: + got = x - y + expected = timedelta(minutes=x.minute - y.minute) + self.assertEqual(got, expected) + + # OTOH, if the tzinfo members are distinct, utcoffsets aren't + # ignored. + base = cls(8, 9, 10, 11, 12, 13, 14) + d0 = base.replace(minute=3, tzinfo=OperandDependentOffset()) + d1 = base.replace(minute=9, tzinfo=OperandDependentOffset()) + d2 = base.replace(minute=11, tzinfo=OperandDependentOffset()) + for x in d0, d1, d2: + for y in d0, d1, d2: + got = x - y + if (x is d0 or x is d1) and (y is d0 or y is d1): + expected = timedelta(0) + elif x is y is d2: + expected = timedelta(0) + elif x is d2: + expected = timedelta(minutes=(11-59)-0) + else: + assert y is d2 + expected = timedelta(minutes=0-(11-59)) + self.assertEqual(got, expected) + + @unittest.expectedFailure + def test_mixed_compare(self): + t1 = datetime(1, 2, 3, 4, 5, 6, 7) + t2 = datetime(1, 2, 3, 4, 5, 6, 7) + self.assertEqual(t1, t2) + t2 = t2.replace(tzinfo=None) + self.assertEqual(t1, t2) + t2 = t2.replace(tzinfo=FixedOffset(None, "")) + self.assertEqual(t1, t2) + t2 = t2.replace(tzinfo=FixedOffset(0, "")) + self.assertRaises(TypeError, lambda: t1 == t2) + + # In datetime w/ identical tzinfo objects, utcoffset is ignored. + class Varies(tzinfo): + def __init__(self): + self.offset = timedelta(minutes=22) + def utcoffset(self, t): + self.offset += timedelta(minutes=1) + return self.offset + + v = Varies() + t1 = t2.replace(tzinfo=v) + t2 = t2.replace(tzinfo=v) + self.assertEqual(t1.utcoffset(), timedelta(minutes=23)) + self.assertEqual(t2.utcoffset(), timedelta(minutes=24)) + self.assertEqual(t1, t2) + + # But if they're not identical, it isn't ignored. 
+ t2 = t2.replace(tzinfo=Varies()) + self.assertTrue(t1 < t2) # t1's offset counter still going up + + @unittest.expectedFailure + def test_subclass_datetimetz(self): + + class C(self.theclass): + theAnswer = 42 + + def __new__(cls, *args, **kws): + temp = kws.copy() + extra = temp.pop('extra') + result = self.theclass.__new__(cls, *args, **temp) + result.extra = extra + return result + + def newmeth(self, start): + return start + self.hour + self.year + + args = 2002, 12, 31, 4, 5, 6, 500, FixedOffset(-300, "EST", 1) + + dt1 = self.theclass(*args) + dt2 = C(*args, **{'extra': 7}) + + self.assertEqual(dt2.__class__, C) + self.assertEqual(dt2.theAnswer, 42) + self.assertEqual(dt2.extra, 7) + self.assertEqual(dt1.utcoffset(), dt2.utcoffset()) + self.assertEqual(dt2.newmeth(-7), dt1.hour + dt1.year - 7) + +# Pain to set up DST-aware tzinfo classes. + +def first_sunday_on_or_after(dt): + days_to_go = 6 - dt.weekday() + if days_to_go: + dt += timedelta(days_to_go) + return dt + +ZERO = timedelta(0) +HOUR = timedelta(hours=1) +DAY = timedelta(days=1) +# In the US, DST starts at 2am (standard time) on the first Sunday in April. +DSTSTART = datetime(1, 4, 1, 2) +# and ends at 2am (DST time; 1am standard time) on the last Sunday of Oct, +# which is the first Sunday on or after Oct 25. Because we view 1:MM as +# being standard time on that day, there is no spelling in local time of +# the last hour of DST (that's 1:MM DST, but 1:MM is taken as standard time). +DSTEND = datetime(1, 10, 25, 1) + +class USTimeZone(tzinfo): + + def __init__(self, hours, reprname, stdname, dstname): + self.stdoffset = timedelta(hours=hours) + self.reprname = reprname + self.stdname = stdname + self.dstname = dstname + + def __repr__(self): + return self.reprname + + def tzname(self, dt): + if self.dst(dt): + return self.dstname + else: + return self.stdname + + def utcoffset(self, dt): + return self.stdoffset + self.dst(dt) + + def dst(self, dt): + if dt is None or dt.tzinfo is None: + # An exception instead may be sensible here, in one or more of + # the cases. + return ZERO + assert dt.tzinfo is self + + # Find first Sunday in April. + start = first_sunday_on_or_after(DSTSTART.replace(year=dt.year)) + assert start.weekday() == 6 and start.month == 4 and start.day <= 7 + + # Find last Sunday in October. + end = first_sunday_on_or_after(DSTEND.replace(year=dt.year)) + assert end.weekday() == 6 and end.month == 10 and end.day >= 25 + + # Can't compare naive to aware objects, so strip the timezone from + # dt first. + if start <= dt.replace(tzinfo=None) < end: + return HOUR + else: + return ZERO + +Eastern = USTimeZone(-5, "Eastern", "EST", "EDT") +Central = USTimeZone(-6, "Central", "CST", "CDT") +Mountain = USTimeZone(-7, "Mountain", "MST", "MDT") +Pacific = USTimeZone(-8, "Pacific", "PST", "PDT") +utc_real = FixedOffset(0, "UTC", 0) +# For better test coverage, we want another flavor of UTC that's west of +# the Eastern and Pacific timezones. +utc_fake = FixedOffset(-12*60, "UTCfake", 0) + +class TestTimezoneConversions(unittest.TestCase): + # The DST switch times for 2002, in std time. + dston = datetime(2002, 4, 7, 2) + dstoff = datetime(2002, 10, 27, 1) + + theclass = datetime + + # Check a time that's inside DST. + def checkinside(self, dt, tz, utc, dston, dstoff): + self.assertEqual(dt.dst(), HOUR) + + # Conversion to our own timezone is always an identity. 
+ self.assertEqual(dt.astimezone(tz), dt) + + asutc = dt.astimezone(utc) + there_and_back = asutc.astimezone(tz) + + # Conversion to UTC and back isn't always an identity here, + # because there are redundant spellings (in local time) of + # UTC time when DST begins: the clock jumps from 1:59:59 + # to 3:00:00, and a local time of 2:MM:SS doesn't really + # make sense then. The classes above treat 2:MM:SS as + # daylight time then (it's "after 2am"), really an alias + # for 1:MM:SS standard time. The latter form is what + # conversion back from UTC produces. + if dt.date() == dston.date() and dt.hour == 2: + # We're in the redundant hour, and coming back from + # UTC gives the 1:MM:SS standard-time spelling. + self.assertEqual(there_and_back + HOUR, dt) + # Although during was considered to be in daylight + # time, there_and_back is not. + self.assertEqual(there_and_back.dst(), ZERO) + # They're the same times in UTC. + self.assertEqual(there_and_back.astimezone(utc), + dt.astimezone(utc)) + else: + # We're not in the redundant hour. + self.assertEqual(dt, there_and_back) + + # Because we have a redundant spelling when DST begins, there is + # (unfortunately) an hour when DST ends that can't be spelled at all in + # local time. When DST ends, the clock jumps from 1:59 back to 1:00 + # again. The hour 1:MM DST has no spelling then: 1:MM is taken to be + # standard time. 1:MM DST == 0:MM EST, but 0:MM is taken to be + # daylight time. The hour 1:MM daylight == 0:MM standard can't be + # expressed in local time. Nevertheless, we want conversion back + # from UTC to mimic the local clock's "repeat an hour" behavior. + nexthour_utc = asutc + HOUR + nexthour_tz = nexthour_utc.astimezone(tz) + if dt.date() == dstoff.date() and dt.hour == 0: + # We're in the hour before the last DST hour. The last DST hour + # is ineffable. We want the conversion back to repeat 1:MM. + self.assertEqual(nexthour_tz, dt.replace(hour=1)) + nexthour_utc += HOUR + nexthour_tz = nexthour_utc.astimezone(tz) + self.assertEqual(nexthour_tz, dt.replace(hour=1)) + else: + self.assertEqual(nexthour_tz - dt, HOUR) + + # Check a time that's outside DST. + def checkoutside(self, dt, tz, utc): + self.assertEqual(dt.dst(), ZERO) + + # Conversion to our own timezone is always an identity. + self.assertEqual(dt.astimezone(tz), dt) + + # Converting to UTC and back is an identity too. + asutc = dt.astimezone(utc) + there_and_back = asutc.astimezone(tz) + self.assertEqual(dt, there_and_back) + + def convert_between_tz_and_utc(self, tz, utc): + dston = self.dston.replace(tzinfo=tz) + # Because 1:MM on the day DST ends is taken as being standard time, + # there is no spelling in tz for the last hour of daylight time. + # For purposes of the test, the last hour of DST is 0:MM, which is + # taken as being daylight time (and 1:MM is taken as being standard + # time). + dstoff = self.dstoff.replace(tzinfo=tz) + for delta in (timedelta(weeks=13), + DAY, + HOUR, + timedelta(minutes=1), + timedelta(microseconds=1)): + + self.checkinside(dston, tz, utc, dston, dstoff) + for during in dston + delta, dstoff - delta: + self.checkinside(during, tz, utc, dston, dstoff) + + self.checkoutside(dstoff, tz, utc) + for outside in dston - delta, dstoff + delta: + self.checkoutside(outside, tz, utc) + + @unittest.expectedFailure + def test_easy(self): + # Despite the name of this test, the endcases are excruciating. 
+ self.convert_between_tz_and_utc(Eastern, utc_real) + self.convert_between_tz_and_utc(Pacific, utc_real) + self.convert_between_tz_and_utc(Eastern, utc_fake) + self.convert_between_tz_and_utc(Pacific, utc_fake) + # The next is really dancing near the edge. It works because + # Pacific and Eastern are far enough apart that their "problem + # hours" don't overlap. + self.convert_between_tz_and_utc(Eastern, Pacific) + self.convert_between_tz_and_utc(Pacific, Eastern) + # OTOH, these fail! Don't enable them. The difficulty is that + # the edge case tests assume that every hour is representable in + # the "utc" class. This is always true for a fixed-offset tzinfo + # class (lke utc_real and utc_fake), but not for Eastern or Central. + # For these adjacent DST-aware time zones, the range of time offsets + # tested ends up creating hours in the one that aren't representable + # in the other. For the same reason, we would see failures in the + # Eastern vs Pacific tests too if we added 3*HOUR to the list of + # offset deltas in convert_between_tz_and_utc(). + # + # self.convert_between_tz_and_utc(Eastern, Central) # can't work + # self.convert_between_tz_and_utc(Central, Eastern) # can't work + + @unittest.expectedFailure + def test_tricky(self): + # 22:00 on day before daylight starts. + fourback = self.dston - timedelta(hours=4) + ninewest = FixedOffset(-9*60, "-0900", 0) + fourback = fourback.replace(tzinfo=ninewest) + # 22:00-0900 is 7:00 UTC == 2:00 EST == 3:00 DST. Since it's "after + # 2", we should get the 3 spelling. + # If we plug 22:00 the day before into Eastern, it "looks like std + # time", so its offset is returned as -5, and -5 - -9 = 4. Adding 4 + # to 22:00 lands on 2:00, which makes no sense in local time (the + # local clock jumps from 1 to 3). The point here is to make sure we + # get the 3 spelling. + expected = self.dston.replace(hour=3) + got = fourback.astimezone(Eastern).replace(tzinfo=None) + self.assertEqual(expected, got) + + # Similar, but map to 6:00 UTC == 1:00 EST == 2:00 DST. In that + # case we want the 1:00 spelling. + sixutc = self.dston.replace(hour=6, tzinfo=utc_real) + # Now 6:00 "looks like daylight", so the offset wrt Eastern is -4, + # and adding -4-0 == -4 gives the 2:00 spelling. We want the 1:00 EST + # spelling. + expected = self.dston.replace(hour=1) + got = sixutc.astimezone(Eastern).replace(tzinfo=None) + self.assertEqual(expected, got) + + # Now on the day DST ends, we want "repeat an hour" behavior. + # UTC 4:MM 5:MM 6:MM 7:MM checking these + # EST 23:MM 0:MM 1:MM 2:MM + # EDT 0:MM 1:MM 2:MM 3:MM + # wall 0:MM 1:MM 1:MM 2:MM against these + for utc in utc_real, utc_fake: + for tz in Eastern, Pacific: + first_std_hour = self.dstoff - timedelta(hours=2) # 23:MM + # Convert that to UTC. + first_std_hour -= tz.utcoffset(None) + # Adjust for possibly fake UTC. + asutc = first_std_hour + utc.utcoffset(None) + # First UTC hour to convert; this is 4:00 when utc=utc_real & + # tz=Eastern. + asutcbase = asutc.replace(tzinfo=utc) + for tzhour in (0, 1, 1, 2): + expectedbase = self.dstoff.replace(hour=tzhour) + for minute in 0, 30, 59: + expected = expectedbase.replace(minute=minute) + asutc = asutcbase.replace(minute=minute) + astz = asutc.astimezone(tz) + self.assertEqual(astz.replace(tzinfo=None), expected) + asutcbase += HOUR + + + @unittest.expectedFailure + def test_bogus_dst(self): + class ok(tzinfo): + def utcoffset(self, dt): return HOUR + def dst(self, dt): return HOUR + + now = self.theclass.now().replace(tzinfo=utc_real) + # Doesn't blow up. 
+ now.astimezone(ok()) + + # Does blow up. + class notok(ok): + def dst(self, dt): return None + self.assertRaises(ValueError, now.astimezone, notok()) + + @unittest.expectedFailure + def test_fromutc(self): + self.assertRaises(TypeError, Eastern.fromutc) # not enough args + now = datetime.utcnow().replace(tzinfo=utc_real) + self.assertRaises(ValueError, Eastern.fromutc, now) # wrong tzinfo + now = now.replace(tzinfo=Eastern) # insert correct tzinfo + enow = Eastern.fromutc(now) # doesn't blow up + self.assertEqual(enow.tzinfo, Eastern) # has right tzinfo member + self.assertRaises(TypeError, Eastern.fromutc, now, now) # too many args + self.assertRaises(TypeError, Eastern.fromutc, date.today()) # wrong type + + # Always converts UTC to standard time. + class FauxUSTimeZone(USTimeZone): + def fromutc(self, dt): + return dt + self.stdoffset + FEastern = FauxUSTimeZone(-5, "FEastern", "FEST", "FEDT") + + # UTC 4:MM 5:MM 6:MM 7:MM 8:MM 9:MM + # EST 23:MM 0:MM 1:MM 2:MM 3:MM 4:MM + # EDT 0:MM 1:MM 2:MM 3:MM 4:MM 5:MM + + # Check around DST start. + start = self.dston.replace(hour=4, tzinfo=Eastern) + fstart = start.replace(tzinfo=FEastern) + for wall in 23, 0, 1, 3, 4, 5: + expected = start.replace(hour=wall) + if wall == 23: + expected -= timedelta(days=1) + got = Eastern.fromutc(start) + self.assertEqual(expected, got) + + expected = fstart + FEastern.stdoffset + got = FEastern.fromutc(fstart) + self.assertEqual(expected, got) + + # Ensure astimezone() calls fromutc() too. + got = fstart.replace(tzinfo=utc_real).astimezone(FEastern) + self.assertEqual(expected, got) + + start += HOUR + fstart += HOUR + + # Check around DST end. + start = self.dstoff.replace(hour=4, tzinfo=Eastern) + fstart = start.replace(tzinfo=FEastern) + for wall in 0, 1, 1, 2, 3, 4: + expected = start.replace(hour=wall) + got = Eastern.fromutc(start) + self.assertEqual(expected, got) + + expected = fstart + FEastern.stdoffset + got = FEastern.fromutc(fstart) + self.assertEqual(expected, got) + + # Ensure astimezone() calls fromutc() too. + got = fstart.replace(tzinfo=utc_real).astimezone(FEastern) + self.assertEqual(expected, got) + + start += HOUR + fstart += HOUR + + +############################################################################# +# oddballs + +class Oddballs(unittest.TestCase): + + @unittest.expectedFailure + def test_bug_1028306(self): + # Trying to compare a date to a datetime should act like a mixed- + # type comparison, despite that datetime is a subclass of date. + as_date = date.today() + as_datetime = datetime.combine(as_date, time()) + self.assertTrue(as_date != as_datetime) + self.assertTrue(as_datetime != as_date) + self.assertFalse(as_date == as_datetime) + self.assertFalse(as_datetime == as_date) + self.assertRaises(TypeError, lambda: as_date < as_datetime) + self.assertRaises(TypeError, lambda: as_datetime < as_date) + self.assertRaises(TypeError, lambda: as_date <= as_datetime) + self.assertRaises(TypeError, lambda: as_datetime <= as_date) + self.assertRaises(TypeError, lambda: as_date > as_datetime) + self.assertRaises(TypeError, lambda: as_datetime > as_date) + self.assertRaises(TypeError, lambda: as_date >= as_datetime) + self.assertRaises(TypeError, lambda: as_datetime >= as_date) + + # Neverthelss, comparison should work with the base-class (date) + # projection if use of a date method is forced. 
+ self.assertTrue(as_date.__eq__(as_datetime)) + different_day = (as_date.day + 1) % 20 + 1 + self.assertFalse(as_date.__eq__(as_datetime.replace(day=different_day))) + + # And date should compare with other subclasses of date. If a + # subclass wants to stop this, it's up to the subclass to do so. + date_sc = SubclassDate(as_date.year, as_date.month, as_date.day) + self.assertEqual(as_date, date_sc) + self.assertEqual(date_sc, as_date) + + # Ditto for datetimes. + datetime_sc = SubclassDatetime(as_datetime.year, as_datetime.month, + as_date.day, 0, 0, 0) + self.assertEqual(as_datetime, datetime_sc) + self.assertEqual(datetime_sc, as_datetime) + +def test_main(): + test_support.run_unittest(__name__) + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_dict.py b/third_party/stdlib/test/test_dict.py index f7f03db7..9a4f3f9a 100644 --- a/third_party/stdlib/test/test_dict.py +++ b/third_party/stdlib/test/test_dict.py @@ -288,7 +288,6 @@ def test_get(self): self.assertRaises(TypeError, d.get) self.assertRaises(TypeError, d.get, None, None, None) - @unittest.expectedFailure def test_setdefault(self): # dict.setdefault() d = {} @@ -316,7 +315,6 @@ def __hash__(self): x.fail = True self.assertRaises(Exc, d.setdefault, x, []) - @unittest.expectedFailure def test_setdefault_atomic(self): # Issue #13521: setdefault() calls __hash__ and __eq__ only once. class Hashed(object): diff --git a/third_party/stdlib/test/test_dircache.py b/third_party/stdlib/test/test_dircache.py new file mode 100644 index 00000000..3926d659 --- /dev/null +++ b/third_party/stdlib/test/test_dircache.py @@ -0,0 +1,78 @@ +""" + Test cases for the dircache module + Nick Mathewson +""" + +import unittest +from test.test_support import run_unittest # , import_module +# dircache = import_module('dircache', deprecated=True) +import dircache +import os, time, sys, tempfile + + +class DircacheTests(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + for fname in os.listdir(self.tempdir): + self.delTemp(fname) + os.rmdir(self.tempdir) + + def writeTemp(self, fname): + f = open(os.path.join(self.tempdir, fname), 'w') + f.close() + + def mkdirTemp(self, fname): + os.mkdir(os.path.join(self.tempdir, fname)) + + def delTemp(self, fname): + fname = os.path.join(self.tempdir, fname) + if os.path.isdir(fname): + os.rmdir(fname) + else: + os.unlink(fname) + + def test_listdir(self): + ## SUCCESSFUL CASES + entries = dircache.listdir(self.tempdir) + self.assertEqual(entries, []) + + # Check that cache is actually caching, not just passing through. + self.assertTrue(dircache.listdir(self.tempdir) is entries) + + # Directories aren't "files" on Windows, and directory mtime has + # nothing to do with when files under a directory get created. + # That is, this test can't possibly work under Windows -- dircache + # is only good for capturing a one-shot snapshot there. + + if sys.platform[:3] not in ('win', 'os2'): + # Sadly, dircache has the same granularity as stat.mtime, and so + # can't notice any changes that occurred within 1 sec of the last + # time it examined a directory. 
+ time.sleep(1) + self.writeTemp("test1") + entries = dircache.listdir(self.tempdir) + self.assertEqual(entries, ['test1']) + self.assertTrue(dircache.listdir(self.tempdir) is entries) + + ## UNSUCCESSFUL CASES + self.assertRaises(OSError, dircache.listdir, self.tempdir+"_nonexistent") + + def test_annotate(self): + self.writeTemp("test2") + self.mkdirTemp("A") + lst = ['A', 'test2', 'test_nonexistent'] + dircache.annotate(self.tempdir, lst) + self.assertEqual(lst, ['A/', 'test2', 'test_nonexistent']) + + +def test_main(): + try: + run_unittest(DircacheTests) + finally: + dircache.reset() + + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_dummy_thread.py b/third_party/stdlib/test/test_dummy_thread.py new file mode 100644 index 00000000..29a85315 --- /dev/null +++ b/third_party/stdlib/test/test_dummy_thread.py @@ -0,0 +1,182 @@ +"""Generic thread tests. + +Meant to be used by dummy_thread and thread. To allow for different modules +to be used, test_main() can be called with the module to use as the thread +implementation as its sole argument. + +""" +import dummy_thread as _thread +import time +import Queue +import random +import unittest +from test import test_support + +DELAY = 0 # Set > 0 when testing a module other than dummy_thread, such as + # the 'thread' module. + +class LockTests(unittest.TestCase): + """Test lock objects.""" + + def setUp(self): + # Create a lock + self.lock = _thread.allocate_lock() + + def test_initlock(self): + #Make sure locks start locked + self.assertFalse(self.lock.locked(), + "Lock object is not initialized unlocked.") + + def test_release(self): + # Test self.lock.release() + self.lock.acquire() + self.lock.release() + self.assertFalse(self.lock.locked(), + "Lock object did not release properly.") + + def test_improper_release(self): + #Make sure release of an unlocked thread raises _thread.error + self.assertRaises(_thread.error, self.lock.release) + + def test_cond_acquire_success(self): + #Make sure the conditional acquiring of the lock works. + self.assertTrue(self.lock.acquire(0), + "Conditional acquiring of the lock failed.") + + def test_cond_acquire_fail(self): + #Test acquiring locked lock returns False + self.lock.acquire(0) + self.assertFalse(self.lock.acquire(0), + "Conditional acquiring of a locked lock incorrectly " + "succeeded.") + + def test_uncond_acquire_success(self): + #Make sure unconditional acquiring of a lock works. + self.lock.acquire() + self.assertTrue(self.lock.locked(), + "Uncondional locking failed.") + + def test_uncond_acquire_return_val(self): + #Make sure that an unconditional locking returns True. + self.assertIs(self.lock.acquire(1), True, + "Unconditional locking did not return True.") + self.assertIs(self.lock.acquire(), True) + + def test_uncond_acquire_blocking(self): + #Make sure that unconditional acquiring of a locked lock blocks. + def delay_unlock(to_unlock, delay): + """Hold on to lock for a set amount of time before unlocking.""" + time.sleep(delay) + to_unlock.release() + + self.lock.acquire() + start_time = int(time.time()) + _thread.start_new_thread(delay_unlock,(self.lock, DELAY)) + if test_support.verbose: + print + print "*** Waiting for thread to release the lock "\ + "(approx. %s sec.) 
***" % DELAY + self.lock.acquire() + end_time = int(time.time()) + if test_support.verbose: + print "done" + self.assertGreaterEqual(end_time - start_time, DELAY, + "Blocking by unconditional acquiring failed.") + +class MiscTests(unittest.TestCase): + """Miscellaneous tests.""" + + def test_exit(self): + #Make sure _thread.exit() raises SystemExit + self.assertRaises(SystemExit, _thread.exit) + + def test_ident(self): + #Test sanity of _thread.get_ident() + self.assertIsInstance(_thread.get_ident(), int, + "_thread.get_ident() returned a non-integer") + self.assertNotEqual(_thread.get_ident(), 0, + "_thread.get_ident() returned 0") + + def test_LockType(self): + #Make sure _thread.LockType is the same type as _thread.allocate_locke() + self.assertIsInstance(_thread.allocate_lock(), _thread.LockType, + "_thread.LockType is not an instance of what " + "is returned by _thread.allocate_lock()") + + def test_interrupt_main(self): + #Calling start_new_thread with a function that executes interrupt_main + # should raise KeyboardInterrupt upon completion. + def call_interrupt(): + _thread.interrupt_main() + self.assertRaises(KeyboardInterrupt, _thread.start_new_thread, + call_interrupt, tuple()) + + def test_interrupt_in_main(self): + # Make sure that if interrupt_main is called in main threat that + # KeyboardInterrupt is raised instantly. + self.assertRaises(KeyboardInterrupt, _thread.interrupt_main) + +class ThreadTests(unittest.TestCase): + """Test thread creation.""" + + def test_arg_passing(self): + #Make sure that parameter passing works. + def arg_tester(queue, arg1=False, arg2=False): + """Use to test _thread.start_new_thread() passes args properly.""" + queue.put((arg1, arg2)) + + testing_queue = Queue.Queue(1) + _thread.start_new_thread(arg_tester, (testing_queue, True, True)) + result = testing_queue.get() + self.assertTrue(result[0] and result[1], + "Argument passing for thread creation using tuple failed") + _thread.start_new_thread(arg_tester, tuple(), {'queue':testing_queue, + 'arg1':True, 'arg2':True}) + result = testing_queue.get() + self.assertTrue(result[0] and result[1], + "Argument passing for thread creation using kwargs failed") + _thread.start_new_thread(arg_tester, (testing_queue, True), {'arg2':True}) + result = testing_queue.get() + self.assertTrue(result[0] and result[1], + "Argument passing for thread creation using both tuple" + " and kwargs failed") + + def test_multi_creation(self): + #Make sure multiple threads can be created. + def queue_mark(queue, delay): + """Wait for ``delay`` seconds and then put something into ``queue``""" + time.sleep(delay) + queue.put(_thread.get_ident()) + + thread_count = 5 + testing_queue = Queue.Queue(thread_count) + if test_support.verbose: + print + print "*** Testing multiple thread creation "\ + "(will take approx. %s to %s sec.) ***" % (DELAY, thread_count) + for count in xrange(thread_count): + if DELAY: + local_delay = round(random.random(), 1) + else: + local_delay = 0 + _thread.start_new_thread(queue_mark, + (testing_queue, local_delay)) + time.sleep(DELAY) + if test_support.verbose: + print 'done' + self.assertEqual(testing_queue.qsize(), thread_count, + "Not all %s threads executed properly after %s sec." 
% + (thread_count, DELAY)) + +def test_main(imported_module=None): + global _thread, DELAY + if imported_module: + _thread = imported_module + DELAY = 2 + if test_support.verbose: + print + print "*** Using %s as _thread module ***" % _thread + test_support.run_unittest(LockTests, MiscTests, ThreadTests) + +if __name__ == '__main__': + test_main() diff --git a/third_party/stdlib/test/test_fpformat.py b/third_party/stdlib/test/test_fpformat.py new file mode 100644 index 00000000..c7b52e89 --- /dev/null +++ b/third_party/stdlib/test/test_fpformat.py @@ -0,0 +1,78 @@ +''' + Tests for fpformat module + Nick Mathewson +''' +from test.test_support import run_unittest #, import_module +import unittest +# fpformat = import_module('fpformat', deprecated=True) +import fpformat +fix, sci, NotANumber = fpformat.fix, fpformat.sci, fpformat.NotANumber + +StringType = type('') + +# Test the old and obsolescent fpformat module. +# +# (It's obsolescent because fix(n,d) == "%.*f"%(d,n) and +# sci(n,d) == "%.*e"%(d,n) +# for all reasonable numeric n and d, except that sci gives 3 exponent +# digits instead of 2. +# +# Differences only occur for unreasonable n and d. <.2 wink>) + +class FpformatTest(unittest.TestCase): + + def checkFix(self, n, digits): + result = fix(n, digits) + if isinstance(n, StringType): + n = repr(n) + expected = "%.*f" % (digits, float(n)) + + self.assertEqual(result, expected) + + def checkSci(self, n, digits): + result = sci(n, digits) + if isinstance(n, StringType): + n = repr(n) + expected = "%.*e" % (digits, float(n)) + # add the extra 0 if needed + num, exp = expected.split("e") + if len(exp) < 4: + exp = exp[0] + "0" + exp[1:] + expected = "%se%s" % (num, exp) + + self.assertEqual(result, expected) + + def test_basic_cases(self): + self.assertEqual(fix(100.0/3, 3), '33.333') + self.assertEqual(sci(100.0/3, 3), '3.333e+001') + + @unittest.skip('grumpy') + def test_reasonable_values(self): + for d in range(7): + for val in (1000.0/3, 1000, 1000.0, .002, 1.0/3, 1e10): + for realVal in (val, 1.0/val, -val, -1.0/val): + self.checkFix(realVal, d) + self.checkSci(realVal, d) + + def test_failing_values(self): + # Now for 'unreasonable n and d' + self.assertEqual(fix(1.0, 1000), '1.'+('0'*1000)) + self.assertEqual(sci("1"+('0'*1000), 0), '1e+1000') + + # This behavior is inconsistent. sci raises an exception; fix doesn't. 
+ yacht = "Throatwobbler Mangrove" + self.assertEqual(fix(yacht, 10), yacht) + try: + sci(yacht, 10) + except NotANumber: + pass + else: + self.fail("No exception on non-numeric sci") + + +def test_main(): + run_unittest(FpformatTest) + + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_genericpath.py b/third_party/stdlib/test/test_genericpath.py new file mode 100644 index 00000000..03c283fe --- /dev/null +++ b/third_party/stdlib/test/test_genericpath.py @@ -0,0 +1,287 @@ +""" +Tests common to genericpath, macpath, ntpath and posixpath +""" + +import unittest +from test import test_support +import os +import genericpath +import sys + + +def safe_rmdir(dirname): + try: + os.rmdir(dirname) + except OSError: + pass + + +class GenericTest(unittest.TestCase): + # The path module to be tested + pathmodule = genericpath + common_attributes = ['commonprefix', 'getsize', 'getatime', 'getctime', + 'getmtime', 'exists', 'isdir', 'isfile'] + attributes = [] + + def test_no_argument(self): + for attr in self.common_attributes + self.attributes: + with self.assertRaises(TypeError): + getattr(self.pathmodule, attr)() + raise self.fail("{}.{}() did not raise a TypeError" + .format(self.pathmodule.__name__, attr)) + + def test_commonprefix(self): + commonprefix = self.pathmodule.commonprefix + self.assertEqual( + commonprefix([]), + "" + ) + self.assertEqual( + commonprefix(["/home/swenson/spam", "/home/swen/spam"]), + "/home/swen" + ) + self.assertEqual( + commonprefix(["/home/swen/spam", "/home/swen/eggs"]), + "/home/swen/" + ) + self.assertEqual( + commonprefix(["/home/swen/spam", "/home/swen/spam"]), + "/home/swen/spam" + ) + self.assertEqual( + commonprefix(["home:swenson:spam", "home:swen:spam"]), + "home:swen" + ) + self.assertEqual( + commonprefix([":home:swen:spam", ":home:swen:eggs"]), + ":home:swen:" + ) + self.assertEqual( + commonprefix([":home:swen:spam", ":home:swen:spam"]), + ":home:swen:spam" + ) + + testlist = ['', 'abc', 'Xbcd', 'Xb', 'XY', 'abcd', + 'aXc', 'abd', 'ab', 'aX', 'abcX'] + for s1 in testlist: + for s2 in testlist: + p = commonprefix([s1, s2]) + self.assertTrue(s1.startswith(p)) + self.assertTrue(s2.startswith(p)) + if s1 != s2: + n = len(p) + self.assertNotEqual(s1[n:n+1], s2[n:n+1]) + + def test_getsize(self): + f = open(test_support.TESTFN, "wb") + try: + f.write("foo") + f.close() + self.assertEqual(self.pathmodule.getsize(test_support.TESTFN), 3) + finally: + if not f.closed: + f.close() + test_support.unlink(test_support.TESTFN) + + @unittest.skip('grumpy') + def test_time(self): + f = open(test_support.TESTFN, "wb") + try: + f.write("foo") + f.close() + f = open(test_support.TESTFN, "ab") + f.write("bar") + f.close() + f = open(test_support.TESTFN, "rb") + d = f.read() + f.close() + self.assertEqual(d, "foobar") + + self.assertLessEqual( + self.pathmodule.getctime(test_support.TESTFN), + self.pathmodule.getmtime(test_support.TESTFN) + ) + finally: + if not f.closed: + f.close() + test_support.unlink(test_support.TESTFN) + + def test_exists(self): + self.assertIs(self.pathmodule.exists(test_support.TESTFN), False) + f = open(test_support.TESTFN, "wb") + try: + f.write("foo") + f.close() + self.assertIs(self.pathmodule.exists(test_support.TESTFN), True) + if not self.pathmodule == genericpath: + self.assertIs(self.pathmodule.lexists(test_support.TESTFN), + True) + finally: + if not f.close(): + f.close() + test_support.unlink(test_support.TESTFN) + + def test_isdir(self): + self.assertIs(self.pathmodule.isdir(test_support.TESTFN), 
False) + f = open(test_support.TESTFN, "wb") + try: + f.write("foo") + f.close() + self.assertIs(self.pathmodule.isdir(test_support.TESTFN), False) + os.remove(test_support.TESTFN) + os.mkdir(test_support.TESTFN) + self.assertIs(self.pathmodule.isdir(test_support.TESTFN), True) + os.rmdir(test_support.TESTFN) + finally: + if not f.close(): + f.close() + test_support.unlink(test_support.TESTFN) + safe_rmdir(test_support.TESTFN) + + @unittest.skip('grumpy') + def test_isfile(self): + self.assertIs(self.pathmodule.isfile(test_support.TESTFN), False) + f = open(test_support.TESTFN, "wb") + try: + f.write("foo") + f.close() + self.assertIs(self.pathmodule.isfile(test_support.TESTFN), True) + os.remove(test_support.TESTFN) + os.mkdir(test_support.TESTFN) + self.assertIs(self.pathmodule.isfile(test_support.TESTFN), False) + os.rmdir(test_support.TESTFN) + finally: + if not f.close(): + f.close() + test_support.unlink(test_support.TESTFN) + safe_rmdir(test_support.TESTFN) + + +# Following TestCase is not supposed to be run from test_genericpath. +# It is inherited by other test modules (macpath, ntpath, posixpath). + +class CommonTest(GenericTest): + # The path module to be tested + pathmodule = None + common_attributes = GenericTest.common_attributes + [ + # Properties + 'curdir', 'pardir', 'extsep', 'sep', + 'pathsep', 'defpath', 'altsep', 'devnull', + # Methods + 'normcase', 'splitdrive', 'expandvars', 'normpath', 'abspath', + 'join', 'split', 'splitext', 'isabs', 'basename', 'dirname', + 'lexists', 'islink', 'ismount', 'expanduser', 'normpath', 'realpath', + ] + + def test_normcase(self): + # Check that normcase() is idempotent + p = "FoO/./BaR" + p = self.pathmodule.normcase(p) + self.assertEqual(p, self.pathmodule.normcase(p)) + + def test_splitdrive(self): + # splitdrive for non-NT paths + splitdrive = self.pathmodule.splitdrive + self.assertEqual(splitdrive("/foo/bar"), ("", "/foo/bar")) + self.assertEqual(splitdrive("foo:bar"), ("", "foo:bar")) + self.assertEqual(splitdrive(":foo:bar"), ("", ":foo:bar")) + + def test_expandvars(self): + if self.pathmodule.__name__ == 'macpath': + self.skipTest('macpath.expandvars is a stub') + expandvars = self.pathmodule.expandvars + with test_support.EnvironmentVarGuard() as env: + env.clear() + env["foo"] = "bar" + env["{foo"] = "baz1" + env["{foo}"] = "baz2" + self.assertEqual(expandvars("foo"), "foo") + self.assertEqual(expandvars("$foo bar"), "bar bar") + self.assertEqual(expandvars("${foo}bar"), "barbar") + self.assertEqual(expandvars("$[foo]bar"), "$[foo]bar") + self.assertEqual(expandvars("$bar bar"), "$bar bar") + self.assertEqual(expandvars("$?bar"), "$?bar") + self.assertEqual(expandvars("$foo}bar"), "bar}bar") + self.assertEqual(expandvars("${foo"), "${foo") + self.assertEqual(expandvars("${{foo}}"), "baz1}") + self.assertEqual(expandvars("$foo$foo"), "barbar") + self.assertEqual(expandvars("$bar$bar"), "$bar$bar") + + @unittest.skipUnless(test_support.FS_NONASCII, 'need test_support.FS_NONASCII') + def test_expandvars_nonascii(self): + if self.pathmodule.__name__ == 'macpath': + self.skipTest('macpath.expandvars is a stub') + expandvars = self.pathmodule.expandvars + def check(value, expected): + self.assertEqual(expandvars(value), expected) + encoding = sys.getfilesystemencoding() + with test_support.EnvironmentVarGuard() as env: + env.clear() + unonascii = test_support.FS_NONASCII + snonascii = unonascii.encode(encoding) + env['spam'] = snonascii + env[snonascii] = 'ham' + snonascii + check(snonascii, snonascii) + check('$spam bar', '%s 
bar' % snonascii) + check('${spam}bar', '%sbar' % snonascii) + check('${%s}bar' % snonascii, 'ham%sbar' % snonascii) + check('$bar%s bar' % snonascii, '$bar%s bar' % snonascii) + check('$spam}bar', '%s}bar' % snonascii) + + check(unonascii, unonascii) + check(u'$spam bar', u'%s bar' % unonascii) + check(u'${spam}bar', u'%sbar' % unonascii) + check(u'${%s}bar' % unonascii, u'ham%sbar' % unonascii) + check(u'$bar%s bar' % unonascii, u'$bar%s bar' % unonascii) + check(u'$spam}bar', u'%s}bar' % unonascii) + + def test_abspath(self): + self.assertIn("foo", self.pathmodule.abspath("foo")) + + # Abspath returns bytes when the arg is bytes + for path in ('', 'foo', 'f\xf2\xf2', '/foo', 'C:\\'): + self.assertIsInstance(self.pathmodule.abspath(path), str) + + def test_realpath(self): + self.assertIn("foo", self.pathmodule.realpath("foo")) + + @test_support.requires_unicode + def test_normpath_issue5827(self): + # Make sure normpath preserves unicode + for path in (u'', u'.', u'/', u'\\', u'///foo/.//bar//'): + self.assertIsInstance(self.pathmodule.normpath(path), unicode) + + @test_support.requires_unicode + def test_abspath_issue3426(self): + # Check that abspath returns unicode when the arg is unicode + # with both ASCII and non-ASCII cwds. + abspath = self.pathmodule.abspath + for path in (u'', u'fuu', u'f\xf9\xf9', u'/fuu', u'U:\\'): + self.assertIsInstance(abspath(path), unicode) + + unicwd = u'\xe7w\xf0' + try: + fsencoding = test_support.TESTFN_ENCODING or "ascii" + unicwd.encode(fsencoding) + except (AttributeError, UnicodeEncodeError): + # FS encoding is probably ASCII + pass + else: + with test_support.temp_cwd(unicwd): + for path in (u'', u'fuu', u'f\xf9\xf9', u'/fuu', u'U:\\'): + self.assertIsInstance(abspath(path), unicode) + + @unittest.skipIf(sys.platform == 'darwin', + "Mac OS X denies the creation of a directory with an invalid utf8 name") + def test_nonascii_abspath(self): + # Test non-ASCII, non-UTF8 bytes in the path. 
+ with test_support.temp_cwd('\xe7w\xf0'): + self.test_abspath() + + +def test_main(): + test_support.run_unittest(GenericTest) + + +if __name__=="__main__": + test_main() diff --git a/third_party/stdlib/test/test_md5.py b/third_party/stdlib/test/test_md5.py new file mode 100644 index 00000000..f03fdf8e --- /dev/null +++ b/third_party/stdlib/test/test_md5.py @@ -0,0 +1,63 @@ +# Testing md5 module +import warnings +warnings.filterwarnings("ignore", "the md5 module is deprecated.*", + DeprecationWarning) + +import unittest +# from md5 import md5 +import md5 as _md5 +md5 = _md5.md5 +from test import test_support + +def hexstr(s): + import string + h = string.hexdigits + r = '' + for c in s: + i = ord(c) + r = r + h[(i >> 4) & 0xF] + h[i & 0xF] + return r + +class MD5_Test(unittest.TestCase): + + def md5test(self, s, expected): + self.assertEqual(hexstr(md5(s).digest()), expected) + self.assertEqual(md5(s).hexdigest(), expected) + + def test_basics(self): + eq = self.md5test + eq('', 'd41d8cd98f00b204e9800998ecf8427e') + eq('a', '0cc175b9c0f1b6a831c399e269772661') + eq('abc', '900150983cd24fb0d6963f7d28e17f72') + eq('message digest', 'f96b697d7cb7938d525a2f31aaf161d0') + eq('abcdefghijklmnopqrstuvwxyz', 'c3fcd3d76192e4007dfb496cca67e13b') + eq('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789', + 'd174ab98d277d9f5a5611c2c9f419d9f') + eq('12345678901234567890123456789012345678901234567890123456789012345678901234567890', + '57edf4a22be3c955ac49da2e2107b67a') + + def test_hexdigest(self): + # hexdigest is new with Python 2.0 + m = md5('testing the hexdigest method') + h = m.hexdigest() + self.assertEqual(hexstr(m.digest()), h) + + def test_large_update(self): + aas = 'a' * 64 + bees = 'b' * 64 + cees = 'c' * 64 + + m1 = md5() + m1.update(aas) + m1.update(bees) + m1.update(cees) + + m2 = md5() + m2.update(aas + bees + cees) + self.assertEqual(m1.digest(), m2.digest()) + +def test_main(): + test_support.run_unittest(MD5_Test) + +if __name__ == '__main__': + test_main() diff --git a/third_party/stdlib/test/test_mimetools.py b/third_party/stdlib/test/test_mimetools.py new file mode 100644 index 00000000..e3745ee5 --- /dev/null +++ b/third_party/stdlib/test/test_mimetools.py @@ -0,0 +1,55 @@ +import unittest +from test import test_support + +import string +import StringIO + +#mimetools = test_support.import_module("mimetools", deprecated=True) +import mimetools + +msgtext1 = mimetools.Message(StringIO.StringIO( +"""Content-Type: text/plain; charset=iso-8859-1; format=flowed +Content-Transfer-Encoding: 8bit + +Foo! 
+""")) + +class MimeToolsTest(unittest.TestCase): + + def test_decodeencode(self): + start = string.ascii_letters + "=" + string.digits + "\n" + for enc in ['7bit','8bit','base64','quoted-printable', + 'uuencode', 'x-uuencode', 'uue', 'x-uue']: + i = StringIO.StringIO(start) + o = StringIO.StringIO() + mimetools.encode(i, o, enc) + i = StringIO.StringIO(o.getvalue()) + o = StringIO.StringIO() + mimetools.decode(i, o, enc) + self.assertEqual(o.getvalue(), start) + + @unittest.expectedFailure + def test_boundary(self): + s = set([""]) + for i in xrange(100): + nb = mimetools.choose_boundary() + self.assertNotIn(nb, s) + s.add(nb) + + def test_message(self): + msg = mimetools.Message(StringIO.StringIO(msgtext1)) + self.assertEqual(msg.gettype(), "text/plain") + self.assertEqual(msg.getmaintype(), "text") + self.assertEqual(msg.getsubtype(), "plain") + self.assertEqual(msg.getplist(), ["charset=iso-8859-1", "format=flowed"]) + self.assertEqual(msg.getparamnames(), ["charset", "format"]) + self.assertEqual(msg.getparam("charset"), "iso-8859-1") + self.assertEqual(msg.getparam("format"), "flowed") + self.assertEqual(msg.getparam("spam"), None) + self.assertEqual(msg.getencoding(), "8bit") + +def test_main(): + test_support.run_unittest(MimeToolsTest) + +if __name__=="__main__": + test_main() diff --git a/third_party/stdlib/test/test_mutex.py b/third_party/stdlib/test/test_mutex.py new file mode 100644 index 00000000..44dfffa4 --- /dev/null +++ b/third_party/stdlib/test/test_mutex.py @@ -0,0 +1,36 @@ +import unittest +import test.test_support + +# mutex = test.test_support.import_module("mutex", deprecated=True) +import mutex + +class MutexTest(unittest.TestCase): + + def test_lock_and_unlock(self): + + def called_by_mutex(some_data): + self.assertEqual(some_data, "spam") + self.assertTrue(m.test(), "mutex not held") + # Nested locking + m.lock(called_by_mutex2, "eggs") + + def called_by_mutex2(some_data): + self.assertEqual(some_data, "eggs") + self.assertTrue(m.test(), "mutex not held") + self.assertTrue(ready_for_2, + "called_by_mutex2 called too soon") + + m = mutex.mutex() + read_for_2 = False + m.lock(called_by_mutex, "spam") + ready_for_2 = True + # unlock both locks + m.unlock() + m.unlock() + self.assertFalse(m.test(), "mutex still held") + +def test_main(): + test.test_support.run_unittest(MutexTest) + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_queue.py b/third_party/stdlib/test/test_queue.py new file mode 100644 index 00000000..90292e2e --- /dev/null +++ b/third_party/stdlib/test/test_queue.py @@ -0,0 +1,326 @@ +# Some simple queue module tests, plus some failure conditions +# to ensure the Queue locks remain stable. +import Queue +import time +import unittest +from test import test_support +#threading = test_support.import_module('threading') +import threading + +QUEUE_SIZE = 5 + +# A thread to run a function that unclogs a blocked Queue. +class _TriggerThread(threading.Thread): + def __init__(self, fn, args): + self.fn = fn + self.args = args + self.startedEvent = threading.Event() + threading.Thread.__init__(self) + + def run(self): + # The sleep isn't necessary, but is intended to give the blocking + # function in the main thread a chance at actually blocking before + # we unclog it. But if the sleep is longer than the timeout-based + # tests wait in their blocking functions, those tests will fail. 
+ # So we give them much longer timeout values compared to the + # sleep here (I aimed at 10 seconds for blocking functions -- + # they should never actually wait that long - they should make + # progress as soon as we call self.fn()). + time.sleep(0.1) + self.startedEvent.set() + self.fn(*self.args) + + +# Execute a function that blocks, and in a separate thread, a function that +# triggers the release. Returns the result of the blocking function. Caution: +# block_func must guarantee to block until trigger_func is called, and +# trigger_func must guarantee to change queue state so that block_func can make +# enough progress to return. In particular, a block_func that just raises an +# exception regardless of whether trigger_func is called will lead to +# timing-dependent sporadic failures, and one of those went rarely seen but +# undiagnosed for years. Now block_func must be unexceptional. If block_func +# is supposed to raise an exception, call do_exceptional_blocking_test() +# instead. + +class BlockingTestMixin(object): + + def tearDown(self): + self.t = None + + def do_blocking_test(self, block_func, block_args, trigger_func, trigger_args): + self.t = _TriggerThread(trigger_func, trigger_args) + self.t.start() + self.result = block_func(*block_args) + # If block_func returned before our thread made the call, we failed! + if not self.t.startedEvent.is_set(): + self.fail("blocking function '%r' appeared not to block" % + block_func) + self.t.join(10) # make sure the thread terminates + if self.t.is_alive(): + self.fail("trigger function '%r' appeared to not return" % + trigger_func) + return self.result + + # Call this instead if block_func is supposed to raise an exception. + def do_exceptional_blocking_test(self,block_func, block_args, trigger_func, + trigger_args, expected_exception_class): + self.t = _TriggerThread(trigger_func, trigger_args) + self.t.start() + try: + try: + block_func(*block_args) + except expected_exception_class: + raise + else: + self.fail("expected exception of kind %r" % + expected_exception_class) + finally: + self.t.join(10) # make sure the thread terminates + if self.t.is_alive(): + self.fail("trigger function '%r' appeared to not return" % + trigger_func) + if not self.t.startedEvent.is_set(): + self.fail("trigger thread ended but event never set") + + +class BaseQueueTest(BlockingTestMixin): + def setUp(self): + self.cum = 0 + self.cumlock = threading.Lock() + + def simple_queue_test(self, q): + if not q.empty(): + raise RuntimeError, "Call this function with an empty queue" + # I guess we better check things actually queue correctly a little :) + q.put(111) + q.put(333) + q.put(222) + target_order = dict(Queue = [111, 333, 222], + LifoQueue = [222, 333, 111], + PriorityQueue = [111, 222, 333]) + actual_order = [q.get(), q.get(), q.get()] + self.assertEqual(actual_order, target_order[q.__class__.__name__], + "Didn't seem to queue the correct data!") + for i in range(QUEUE_SIZE-1): + q.put(i) + self.assertTrue(not q.empty(), "Queue should not be empty") + self.assertTrue(not q.full(), "Queue should not be full") + last = 2 * QUEUE_SIZE + full = 3 * 2 * QUEUE_SIZE + q.put(last) + self.assertTrue(q.full(), "Queue should be full") + try: + q.put(full, block=0) + self.fail("Didn't appear to block with a full queue") + except Queue.Full: + pass + try: + q.put(full, timeout=0.01) + self.fail("Didn't appear to time-out with a full queue") + except Queue.Full: + pass + # Test a blocking put + self.do_blocking_test(q.put, (full,), q.get, ()) + 
self.do_blocking_test(q.put, (full, True, 10), q.get, ()) + # Empty it + for i in range(QUEUE_SIZE): + q.get() + self.assertTrue(q.empty(), "Queue should be empty") + try: + q.get(block=0) + self.fail("Didn't appear to block with an empty queue") + except Queue.Empty: + pass + try: + q.get(timeout=0.01) + self.fail("Didn't appear to time-out with an empty queue") + except Queue.Empty: + pass + # Test a blocking get + self.do_blocking_test(q.get, (), q.put, ('empty',)) + self.do_blocking_test(q.get, (True, 10), q.put, ('empty',)) + + + def worker(self, q): + while True: + x = q.get() + if x is None: + q.task_done() + return + with self.cumlock: + self.cum += x + q.task_done() + + def queue_join_test(self, q): + self.cum = 0 + for i in (0,1): + threading.Thread(target=self.worker, args=(q,)).start() + for i in xrange(100): + q.put(i) + q.join() + self.assertEqual(self.cum, sum(range(100)), + "q.join() did not block until all tasks were done") + for i in (0,1): + q.put(None) # instruct the threads to close + q.join() # verify that you can join twice + + def test_queue_task_done(self): + # Test to make sure a queue task completed successfully. + q = self.type2test() + try: + q.task_done() + except ValueError: + pass + else: + self.fail("Did not detect task count going negative") + + def test_queue_join(self): + # Test that a queue join()s successfully, and before anything else + # (done twice for insurance). + q = self.type2test() + self.queue_join_test(q) + self.queue_join_test(q) + try: + q.task_done() + except ValueError: + pass + else: + self.fail("Did not detect task count going negative") + + def test_simple_queue(self): + # Do it a couple of times on the same queue. + # Done twice to make sure works with same instance reused. + q = self.type2test(QUEUE_SIZE) + self.simple_queue_test(q) + self.simple_queue_test(q) + + +class QueueTest(BaseQueueTest, unittest.TestCase): + type2test = Queue.Queue + +class LifoQueueTest(BaseQueueTest, unittest.TestCase): + type2test = Queue.LifoQueue + +class PriorityQueueTest(BaseQueueTest, unittest.TestCase): + type2test = Queue.PriorityQueue + + + +# A Queue subclass that can provoke failure at a moment's notice :) +class FailingQueueException(Exception): + pass + +class FailingQueue(Queue.Queue): + def __init__(self, *args): + self.fail_next_put = False + self.fail_next_get = False + Queue.Queue.__init__(self, *args) + def _put(self, item): + if self.fail_next_put: + self.fail_next_put = False + raise FailingQueueException, "You Lose" + return Queue.Queue._put(self, item) + def _get(self): + if self.fail_next_get: + self.fail_next_get = False + raise FailingQueueException, "You Lose" + return Queue.Queue._get(self) + +class FailingQueueTest(BlockingTestMixin, unittest.TestCase): + + def failing_queue_test(self, q): + if not q.empty(): + raise RuntimeError, "Call this function with an empty queue" + for i in range(QUEUE_SIZE-1): + q.put(i) + # Test a failing non-blocking put. 
+ q.fail_next_put = True + try: + q.put("oops", block=0) + self.fail("The queue didn't fail when it should have") + except FailingQueueException: + pass + q.fail_next_put = True + try: + q.put("oops", timeout=0.1) + self.fail("The queue didn't fail when it should have") + except FailingQueueException: + pass + q.put("last") + self.assertTrue(q.full(), "Queue should be full") + # Test a failing blocking put + q.fail_next_put = True + try: + self.do_blocking_test(q.put, ("full",), q.get, ()) + self.fail("The queue didn't fail when it should have") + except FailingQueueException: + pass + # Check the Queue isn't damaged. + # put failed, but get succeeded - re-add + q.put("last") + # Test a failing timeout put + q.fail_next_put = True + try: + self.do_exceptional_blocking_test(q.put, ("full", True, 10), q.get, (), + FailingQueueException) + self.fail("The queue didn't fail when it should have") + except FailingQueueException: + pass + # Check the Queue isn't damaged. + # put failed, but get succeeded - re-add + q.put("last") + self.assertTrue(q.full(), "Queue should be full") + q.get() + self.assertTrue(not q.full(), "Queue should not be full") + q.put("last") + self.assertTrue(q.full(), "Queue should be full") + # Test a blocking put + self.do_blocking_test(q.put, ("full",), q.get, ()) + # Empty it + for i in range(QUEUE_SIZE): + q.get() + self.assertTrue(q.empty(), "Queue should be empty") + q.put("first") + q.fail_next_get = True + try: + q.get() + self.fail("The queue didn't fail when it should have") + except FailingQueueException: + pass + self.assertTrue(not q.empty(), "Queue should not be empty") + q.fail_next_get = True + try: + q.get(timeout=0.1) + self.fail("The queue didn't fail when it should have") + except FailingQueueException: + pass + self.assertTrue(not q.empty(), "Queue should not be empty") + q.get() + self.assertTrue(q.empty(), "Queue should be empty") + q.fail_next_get = True + try: + self.do_exceptional_blocking_test(q.get, (), q.put, ('empty',), + FailingQueueException) + self.fail("The queue didn't fail when it should have") + except FailingQueueException: + pass + # put succeeded, but get failed. + self.assertTrue(not q.empty(), "Queue should not be empty") + q.get() + self.assertTrue(q.empty(), "Queue should be empty") + + def test_failing_queue(self): + # Test to make sure a queue is functioning correctly. + # Done twice to the same instance. + q = FailingQueue(QUEUE_SIZE) + self.failing_queue_test(q) + self.failing_queue_test(q) + + +def test_main(): + test_support.run_unittest(QueueTest, LifoQueueTest, PriorityQueueTest, + FailingQueueTest) + + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_quopri.py b/third_party/stdlib/test/test_quopri.py new file mode 100644 index 00000000..c179a263 --- /dev/null +++ b/third_party/stdlib/test/test_quopri.py @@ -0,0 +1,203 @@ +from test import test_support +import unittest + +import sys, cStringIO #, subprocess +import quopri + + + +ENCSAMPLE = """\ +Here's a bunch of special=20 + +=A1=A2=A3=A4=A5=A6=A7=A8=A9 +=AA=AB=AC=AD=AE=AF=B0=B1=B2=B3 +=B4=B5=B6=B7=B8=B9=BA=BB=BC=BD=BE +=BF=C0=C1=C2=C3=C4=C5=C6 +=C7=C8=C9=CA=CB=CC=CD=CE=CF +=D0=D1=D2=D3=D4=D5=D6=D7 +=D8=D9=DA=DB=DC=DD=DE=DF +=E0=E1=E2=E3=E4=E5=E6=E7 +=E8=E9=EA=EB=EC=ED=EE=EF +=F0=F1=F2=F3=F4=F5=F6=F7 +=F8=F9=FA=FB=FC=FD=FE=FF + +characters... have fun! 
+""" + +# First line ends with a space +DECSAMPLE = "Here's a bunch of special \n" + \ +"""\ + +\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9 +\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3 +\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe +\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6 +\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf +\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7 +\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf +\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7 +\xe8\xe9\xea\xeb\xec\xed\xee\xef +\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7 +\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff + +characters... have fun! +""" + + +def withpythonimplementation(testfunc): + def newtest(self): + # Test default implementation + testfunc(self) + # Test Python implementation + if quopri.b2a_qp is not None or quopri.a2b_qp is not None: + oldencode = quopri.b2a_qp + olddecode = quopri.a2b_qp + try: + quopri.b2a_qp = None + quopri.a2b_qp = None + testfunc(self) + finally: + quopri.b2a_qp = oldencode + quopri.a2b_qp = olddecode + #newtest.__name__ = testfunc.__name__ + return newtest + +class QuopriTestCase(unittest.TestCase): + # Each entry is a tuple of (plaintext, encoded string). These strings are + # used in the "quotetabs=0" tests. + STRINGS = ( + # Some normal strings + ('hello', 'hello'), + ('''hello + there + world''', '''hello + there + world'''), + ('''hello + there + world +''', '''hello + there + world +'''), + ('\201\202\203', '=81=82=83'), + # Add some trailing MUST QUOTE strings + ('hello ', 'hello=20'), + ('hello\t', 'hello=09'), + # Some long lines. First, a single line of 108 characters + ('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\xd8\xd9\xda\xdb\xdc\xdd\xde\xdfxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', + '''xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx=D8=D9=DA=DB=DC=DD=DE=DFx= +xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'''), + # A line of exactly 76 characters, no soft line break should be needed + ('yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy', + 'yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy'), + # A line of 77 characters, forcing a soft line break at position 75, + # and a second line of exactly 2 characters (because the soft line + # break `=' sign counts against the line length limit). + ('zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz', + '''zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz= +zz'''), + # A line of 151 characters, forcing a soft line break at position 75, + # with a second line of exactly 76 characters and no trailing = + ('zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz', + '''zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz= +zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz'''), + # A string containing a hard line break, but which the first line is + # 151 characters and the second line is exactly 76 characters. This + # should leave us with three lines, the first which has a soft line + # break, and which the second and third do not. 
+ ('''yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy +zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz''', + '''yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy= +yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy +zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz'''), + # Now some really complex stuff ;) + (DECSAMPLE, ENCSAMPLE), + ) + + # These are used in the "quotetabs=1" tests. + ESTRINGS = ( + ('hello world', 'hello=20world'), + ('hello\tworld', 'hello=09world'), + ) + + # These are used in the "header=1" tests. + HSTRINGS = ( + ('hello world', 'hello_world'), + ('hello_world', 'hello=5Fworld'), + ) + + @withpythonimplementation + def test_encodestring(self): + for p, e in self.STRINGS: + self.assertTrue(quopri.encodestring(p) == e) + + @withpythonimplementation + def test_decodestring(self): + for p, e in self.STRINGS: + self.assertTrue(quopri.decodestring(e) == p) + + @withpythonimplementation + def test_idempotent_string(self): + for p, e in self.STRINGS: + self.assertTrue(quopri.decodestring(quopri.encodestring(e)) == e) + + @withpythonimplementation + def test_encode(self): + for p, e in self.STRINGS: + infp = cStringIO.StringIO(p) + outfp = cStringIO.StringIO() + quopri.encode(infp, outfp, quotetabs=False) + self.assertTrue(outfp.getvalue() == e) + + @withpythonimplementation + def test_decode(self): + for p, e in self.STRINGS: + infp = cStringIO.StringIO(e) + outfp = cStringIO.StringIO() + quopri.decode(infp, outfp) + self.assertTrue(outfp.getvalue() == p) + + @withpythonimplementation + def test_embedded_ws(self): + for p, e in self.ESTRINGS: + self.assertTrue(quopri.encodestring(p, quotetabs=True) == e) + self.assertTrue(quopri.decodestring(e) == p) + + @withpythonimplementation + def test_encode_header(self): + for p, e in self.HSTRINGS: + self.assertTrue(quopri.encodestring(p, header=True) == e) + + @withpythonimplementation + def test_decode_header(self): + for p, e in self.HSTRINGS: + self.assertTrue(quopri.decodestring(e, header=True) == p) + + @unittest.expectedFailure + def test_scriptencode(self): + (p, e) = self.STRINGS[-1] + process = subprocess.Popen([sys.executable, "-mquopri"], + stdin=subprocess.PIPE, stdout=subprocess.PIPE) + self.addCleanup(process.stdout.close) + cout, cerr = process.communicate(p) + # On Windows, Python will output the result to stdout using + # CRLF, as the mode of stdout is text mode. To compare this + # with the expected result, we need to do a line-by-line comparison. 
+ self.assertEqual(cout.splitlines(), e.splitlines()) + + @unittest.expectedFailure + def test_scriptdecode(self): + (p, e) = self.STRINGS[-1] + process = subprocess.Popen([sys.executable, "-mquopri", "-d"], + stdin=subprocess.PIPE, stdout=subprocess.PIPE) + self.addCleanup(process.stdout.close) + cout, cerr = process.communicate(e) + self.assertEqual(cout.splitlines(), p.splitlines()) + +def test_main(): + test_support.run_unittest(QuopriTestCase) + + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_rfc822.py b/third_party/stdlib/test/test_rfc822.py new file mode 100644 index 00000000..0f89ce16 --- /dev/null +++ b/third_party/stdlib/test/test_rfc822.py @@ -0,0 +1,264 @@ +import unittest +from test import test_support + +#rfc822 = test_support.import_module("rfc822", deprecated=True) +import rfc822 + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + + +class MessageTestCase(unittest.TestCase): + def create_message(self, msg): + return rfc822.Message(StringIO(msg)) + + def test_get(self): + msg = self.create_message( + 'To: "last, first" \n\ntest\n') + self.assertTrue(msg.get("to") == '"last, first" ') + self.assertTrue(msg.get("TO") == '"last, first" ') + self.assertTrue(msg.get("No-Such-Header") is None) + self.assertTrue(msg.get("No-Such-Header", "No-Such-Value") + == "No-Such-Value") + + def test_setdefault(self): + msg = self.create_message( + 'To: "last, first" \n\ntest\n') + self.assertTrue(not msg.has_key("New-Header")) + self.assertTrue(msg.setdefault("New-Header", "New-Value") == "New-Value") + self.assertTrue(msg.setdefault("New-Header", "Different-Value") + == "New-Value") + self.assertTrue(msg["new-header"] == "New-Value") + + self.assertTrue(msg.setdefault("Another-Header") == "") + self.assertTrue(msg["another-header"] == "") + + def check(self, msg, results): + """Check addresses and the date.""" + m = self.create_message(msg) + i = 0 + for n, a in m.getaddrlist('to') + m.getaddrlist('cc'): + try: + mn, ma = results[i][0], results[i][1] + except IndexError: + print 'extra parsed address:', repr(n), repr(a) + continue + i = i + 1 + self.assertEqual(mn, n, + "Un-expected name: %r != %r" % (mn, n)) + self.assertEqual(ma, a, + "Un-expected address: %r != %r" % (ma, a)) + if mn == n and ma == a: + pass + else: + print 'not found:', repr(n), repr(a) + + out = m.getdate('date') + if out: + self.assertEqual(out, + (1999, 1, 13, 23, 57, 35, 0, 1, 0), + "date conversion failed") + + + # Note: all test cases must have the same date (in various formats), + # or no date! + + def test_basic(self): + self.check( + 'Date: Wed, 13 Jan 1999 23:57:35 -0500\n' + 'From: Guido van Rossum \n' + 'To: "Guido van\n' + '\t : Rossum" \n' + 'Subject: test2\n' + '\n' + 'test2\n', + [('Guido van\n\t : Rossum', 'guido@python.org')]) + + self.check( + 'From: Barry \n' + 'Date: 13-Jan-1999 23:57:35 EST\n' + '\n' + 'test', + [('Guido: the Barbarian', 'guido@python.org'), + ('Guido: the Madman', 'guido@python.org') + ]) + + self.check( + 'To: "The monster with\n' + ' the very long name: Guido" \n' + 'Date: Wed, 13 Jan 1999 23:57:35 -0500\n' + '\n' + 'test', + [('The monster with\n the very long name: Guido', + 'guido@python.org')]) + + self.check( + 'To: "Amit J. Patel" \n' + 'CC: Mike Fletcher ,\n' + ' "\'string-sig@python.org\'" \n' + 'Cc: fooz@bat.com, bart@toof.com\n' + 'Cc: goit@lip.com\n' + 'Date: Wed, 13 Jan 1999 23:57:35 -0500\n' + '\n' + 'test', + [('Amit J. 
Patel', 'amitp@Theory.Stanford.EDU'), + ('Mike Fletcher', 'mfletch@vrtelecom.com'), + ("'string-sig@python.org'", 'string-sig@python.org'), + ('', 'fooz@bat.com'), + ('', 'bart@toof.com'), + ('', 'goit@lip.com'), + ]) + + self.check( + 'To: Some One \n' + 'From: Anudder Persin \n' + 'Date:\n' + '\n' + 'test', + [('Some One', 'someone@dom.ain')]) + + self.check( + 'To: person@dom.ain (User J. Person)\n\n', + [('User J. Person', 'person@dom.ain')]) + + def test_doublecomment(self): + # The RFC allows comments within comments in an email addr + self.check( + 'To: person@dom.ain ((User J. Person)), John Doe \n\n', + [('User J. Person', 'person@dom.ain'), ('John Doe', 'foo@bar.com')]) + + def test_twisted(self): + # This one is just twisted. I don't know what the proper + # result should be, but it shouldn't be to infloop, which is + # what used to happen! + self.check( + 'To: <[smtp:dd47@mail.xxx.edu]_at_hmhq@hdq-mdm1-imgout.companay.com>\n' + 'Date: Wed, 13 Jan 1999 23:57:35 -0500\n' + '\n' + 'test', + [('', ''), + ('', 'dd47@mail.xxx.edu'), + ('', '_at_hmhq@hdq-mdm1-imgout.companay.com'), + ]) + + def test_commas_in_full_name(self): + # This exercises the old commas-in-a-full-name bug, which + # should be doing the right thing in recent versions of the + # module. + self.check( + 'To: "last, first" \n' + '\n' + 'test', + [('last, first', 'userid@foo.net')]) + + def test_quoted_name(self): + self.check( + 'To: (Comment stuff) "Quoted name"@somewhere.com\n' + '\n' + 'test', + [('Comment stuff', '"Quoted name"@somewhere.com')]) + + def test_bogus_to_header(self): + self.check( + 'To: :\n' + 'Cc: goit@lip.com\n' + 'Date: Wed, 13 Jan 1999 23:57:35 -0500\n' + '\n' + 'test', + [('', 'goit@lip.com')]) + + def test_addr_ipquad(self): + self.check( + 'To: guido@[132.151.1.21]\n' + '\n' + 'foo', + [('', 'guido@[132.151.1.21]')]) + + def test_iter(self): + m = rfc822.Message(StringIO( + 'Date: Wed, 13 Jan 1999 23:57:35 -0500\n' + 'From: Guido van Rossum \n' + 'To: "Guido van\n' + '\t : Rossum" \n' + 'Subject: test2\n' + '\n' + 'test2\n' )) + self.assertEqual(sorted(m), ['date', 'from', 'subject', 'to']) + + def test_rfc2822_phrases(self): + # RFC 2822 (the update to RFC 822) specifies that dots in phrases are + # obsolete syntax, which conforming programs MUST recognize but NEVER + # generate (see $4.1 Miscellaneous obsolete tokens). This is a + # departure from RFC 822 which did not allow dots in non-quoted + # phrases. + self.check('To: User J. Person \n\n', + [('User J. Person', 'person@dom.ain')]) + + # This takes too long to add to the test suite +## def test_an_excrutiatingly_long_address_field(self): +## OBSCENELY_LONG_HEADER_MULTIPLIER = 10000 +## oneaddr = ('Person' * 10) + '@' + ('.'.join(['dom']*10)) + '.com' +## addr = ', '.join([oneaddr] * OBSCENELY_LONG_HEADER_MULTIPLIER) +## lst = rfc822.AddrlistClass(addr).getaddrlist() +## self.assertEqual(len(lst), OBSCENELY_LONG_HEADER_MULTIPLIER) + + def test_2getaddrlist(self): + eq = self.assertEqual + msg = self.create_message("""\ +To: aperson@dom.ain +Cc: bperson@dom.ain +Cc: cperson@dom.ain +Cc: dperson@dom.ain + +A test message. 
+""") + ccs = [('', a) for a in + ['bperson@dom.ain', 'cperson@dom.ain', 'dperson@dom.ain']] + addrs = msg.getaddrlist('cc') + addrs.sort() + eq(addrs, ccs) + # Try again, this one used to fail + addrs = msg.getaddrlist('cc') + addrs.sort() + eq(addrs, ccs) + + def test_parseaddr(self): + eq = self.assertEqual + eq(rfc822.parseaddr('<>'), ('', '')) + eq(rfc822.parseaddr('aperson@dom.ain'), ('', 'aperson@dom.ain')) + eq(rfc822.parseaddr('bperson@dom.ain (Bea A. Person)'), + ('Bea A. Person', 'bperson@dom.ain')) + eq(rfc822.parseaddr('Cynthia Person '), + ('Cynthia Person', 'cperson@dom.ain')) + + def test_quote_unquote(self): + eq = self.assertEqual + eq(rfc822.quote('foo\\wacky"name'), 'foo\\\\wacky\\"name') + eq(rfc822.unquote('"foo\\\\wacky\\"name"'), 'foo\\wacky"name') + + def test_invalid_headers(self): + eq = self.assertEqual + msg = self.create_message("First: val\n: otherval\nSecond: val2\n") + eq(msg.getheader('First'), 'val') + eq(msg.getheader('Second'), 'val2') + + +def test_main(): + test_support.run_unittest(MessageTestCase) + + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_sched.py b/third_party/stdlib/test/test_sched.py new file mode 100644 index 00000000..8c09df90 --- /dev/null +++ b/third_party/stdlib/test/test_sched.py @@ -0,0 +1,169 @@ +#import Queue as queue +import sched +import time +import unittest +import test.test_support + +try: + import threading +except ImportError: + threading = None + +TIMEOUT = 10 + + +class Timer(object): + def __init__(self): + self._cond = threading.Condition() + self._time = 0 + self._stop = 0 + + def time(self): + with self._cond: + return self._time + + # increase the time but not beyond the established limit + def sleep(self, t): + assert t >= 0 + with self._cond: + t += self._time + while self._stop < t: + self._time = self._stop + self._cond.wait() + self._time = t + + # advance time limit for user code + def advance(self, t): + assert t >= 0 + with self._cond: + self._stop += t + self._cond.notify_all() + + +class TestCase(unittest.TestCase): + + def test_enter(self): + l = [] + fun = lambda x: l.append(x) + scheduler = sched.scheduler(time.time, time.sleep) + for x in [0.5, 0.4, 0.3, 0.2, 0.1]: + z = scheduler.enter(x, 1, fun, (x,)) + scheduler.run() + self.assertEqual(l, [0.1, 0.2, 0.3, 0.4, 0.5]) + + def test_enterabs(self): + l = [] + fun = lambda x: l.append(x) + scheduler = sched.scheduler(time.time, time.sleep) + for x in [0.05, 0.04, 0.03, 0.02, 0.01]: + z = scheduler.enterabs(x, 1, fun, (x,)) + scheduler.run() + self.assertEqual(l, [0.01, 0.02, 0.03, 0.04, 0.05]) + + #@unittest.skipUnless(threading, 'Threading required for this test.') + @unittest.skip('grumpy') + def test_enter_concurrent(self): + q = queue.Queue() + fun = q.put + timer = Timer() + scheduler = sched.scheduler(timer.time, timer.sleep) + scheduler.enter(1, 1, fun, (1,)) + scheduler.enter(3, 1, fun, (3,)) + t = threading.Thread(target=scheduler.run) + t.start() + timer.advance(1) + self.assertEqual(q.get(timeout=TIMEOUT), 1) + self.assertTrue(q.empty()) + for x in [4, 5, 2]: + z = scheduler.enter(x - 1, 1, fun, (x,)) + timer.advance(2) + self.assertEqual(q.get(timeout=TIMEOUT), 2) + self.assertEqual(q.get(timeout=TIMEOUT), 3) + self.assertTrue(q.empty()) + timer.advance(1) + self.assertEqual(q.get(timeout=TIMEOUT), 4) + self.assertTrue(q.empty()) + timer.advance(1) + self.assertEqual(q.get(timeout=TIMEOUT), 5) + self.assertTrue(q.empty()) + timer.advance(1000) + t.join(timeout=TIMEOUT) + self.assertFalse(t.is_alive()) + 
self.assertTrue(q.empty()) + self.assertEqual(timer.time(), 5) + + def test_priority(self): + l = [] + fun = lambda x: l.append(x) + scheduler = sched.scheduler(time.time, time.sleep) + for priority in [1, 2, 3, 4, 5]: + z = scheduler.enterabs(0.01, priority, fun, (priority,)) + scheduler.run() + self.assertEqual(l, [1, 2, 3, 4, 5]) + + @unittest.skip('grumpy') + def test_cancel(self): + l = [] + fun = lambda x: l.append(x) + scheduler = sched.scheduler(time.time, time.sleep) + now = time.time() + event1 = scheduler.enterabs(now + 0.01, 1, fun, (0.01,)) + event2 = scheduler.enterabs(now + 0.02, 1, fun, (0.02,)) + event3 = scheduler.enterabs(now + 0.03, 1, fun, (0.03,)) + event4 = scheduler.enterabs(now + 0.04, 1, fun, (0.04,)) + event5 = scheduler.enterabs(now + 0.05, 1, fun, (0.05,)) + scheduler.cancel(event1) + scheduler.cancel(event5) + scheduler.run() + self.assertEqual(l, [0.02, 0.03, 0.04]) + + #@unittest.skipUnless(threading, 'Threading required for this test.') + @unittest.skip('grumpy') + def test_cancel_concurrent(self): + q = queue.Queue() + fun = q.put + timer = Timer() + scheduler = sched.scheduler(timer.time, timer.sleep) + now = timer.time() + event1 = scheduler.enterabs(now + 1, 1, fun, (1,)) + event2 = scheduler.enterabs(now + 2, 1, fun, (2,)) + event4 = scheduler.enterabs(now + 4, 1, fun, (4,)) + event5 = scheduler.enterabs(now + 5, 1, fun, (5,)) + event3 = scheduler.enterabs(now + 3, 1, fun, (3,)) + t = threading.Thread(target=scheduler.run) + t.start() + timer.advance(1) + self.assertEqual(q.get(timeout=TIMEOUT), 1) + self.assertTrue(q.empty()) + scheduler.cancel(event2) + scheduler.cancel(event5) + timer.advance(1) + self.assertTrue(q.empty()) + timer.advance(1) + self.assertEqual(q.get(timeout=TIMEOUT), 3) + self.assertTrue(q.empty()) + timer.advance(1) + self.assertEqual(q.get(timeout=TIMEOUT), 4) + self.assertTrue(q.empty()) + timer.advance(1000) + t.join(timeout=TIMEOUT) + self.assertFalse(t.is_alive()) + self.assertTrue(q.empty()) + self.assertEqual(timer.time(), 4) + + def test_empty(self): + l = [] + fun = lambda x: l.append(x) + scheduler = sched.scheduler(time.time, time.sleep) + self.assertTrue(scheduler.empty()) + for x in [0.05, 0.04, 0.03, 0.02, 0.01]: + z = scheduler.enterabs(x, 1, fun, (x,)) + self.assertFalse(scheduler.empty()) + scheduler.run() + self.assertTrue(scheduler.empty()) + +def test_main(): + test.test_support.run_unittest(TestCase) + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_select.py b/third_party/stdlib/test/test_select.py new file mode 100644 index 00000000..4b024fc7 --- /dev/null +++ b/third_party/stdlib/test/test_select.py @@ -0,0 +1,67 @@ +from test import test_support +import unittest +import select_ as select +import os +import sys + +@unittest.skipIf(sys.platform[:3] in ('win', 'os2', 'riscos'), + "can't easily test on this system") +class SelectTestCase(unittest.TestCase): + + class Nope(object): + pass + + class Almost(object): + def fileno(self): + return 'fileno' + + def test_error_conditions(self): + self.assertRaises(TypeError, select.select, 1, 2, 3) + self.assertRaises(TypeError, select.select, [self.Nope()], [], []) + self.assertRaises(TypeError, select.select, [self.Almost()], [], []) + self.assertRaises(ValueError, select.select, [], [], [], "not a number") + + def test_returned_list_identity(self): + # See issue #8329 + r, w, x = select.select([], [], [], 1) + self.assertIsNot(r, w) + self.assertIsNot(r, x) + self.assertIsNot(w, x) + + def test_select(self): + cmd = 'for i in 0 1 
2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done' + p = os.popen(cmd, 'r') + for tout in (0, 1, 2, 4, 8, 16) + (None,)*10: + if test_support.verbose: + print 'timeout =', tout + rfd, wfd, xfd = select.select([p], [], [], tout) + if (rfd, wfd, xfd) == ([], [], []): + continue + if (rfd, wfd, xfd) == ([p], [], []): + line = p.readline() + if test_support.verbose: + print repr(line) + if not line: + if test_support.verbose: + print 'EOF' + break + continue + self.fail('Unexpected return values from select():', rfd, wfd, xfd) + p.close() + + # Issue 16230: Crash on select resized list + def test_select_mutated(self): + a = [] + class F(object): + def fileno(self): + del a[-1] + return sys.stdout.fileno() + a[:] = [F()] * 10 + self.assertEqual(select.select([], a, []), ([], a[:5], [])) + +def test_main(): + test_support.run_unittest(SelectTestCase) + test_support.reap_children() + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_stat.py b/third_party/stdlib/test/test_stat.py new file mode 100644 index 00000000..51f078a1 --- /dev/null +++ b/third_party/stdlib/test/test_stat.py @@ -0,0 +1,180 @@ +import unittest +import os +from test.test_support import TESTFN, run_unittest +import stat + +class TestFilemode(unittest.TestCase): + file_flags = {'SF_APPEND', 'SF_ARCHIVED', 'SF_IMMUTABLE', 'SF_NOUNLINK', + 'SF_SNAPSHOT', 'UF_APPEND', 'UF_COMPRESSED', 'UF_HIDDEN', + 'UF_IMMUTABLE', 'UF_NODUMP', 'UF_NOUNLINK', 'UF_OPAQUE'} + + formats = {'S_IFBLK', 'S_IFCHR', 'S_IFDIR', 'S_IFIFO', 'S_IFLNK', + 'S_IFREG', 'S_IFSOCK'} + + format_funcs = {'S_ISBLK', 'S_ISCHR', 'S_ISDIR', 'S_ISFIFO', 'S_ISLNK', + 'S_ISREG', 'S_ISSOCK'} + + stat_struct = { + 'ST_MODE': 0, + 'ST_INO': 1, + 'ST_DEV': 2, + 'ST_NLINK': 3, + 'ST_UID': 4, + 'ST_GID': 5, + 'ST_SIZE': 6, + 'ST_ATIME': 7, + 'ST_MTIME': 8, + 'ST_CTIME': 9} + + # permission bit value are defined by POSIX + permission_bits = { + 'S_ISUID': 0o4000, + 'S_ISGID': 0o2000, + 'S_ENFMT': 0o2000, + 'S_ISVTX': 0o1000, + 'S_IRWXU': 0o700, + 'S_IRUSR': 0o400, + 'S_IREAD': 0o400, + 'S_IWUSR': 0o200, + 'S_IWRITE': 0o200, + 'S_IXUSR': 0o100, + 'S_IEXEC': 0o100, + 'S_IRWXG': 0o070, + 'S_IRGRP': 0o040, + 'S_IWGRP': 0o020, + 'S_IXGRP': 0o010, + 'S_IRWXO': 0o007, + 'S_IROTH': 0o004, + 'S_IWOTH': 0o002, + 'S_IXOTH': 0o001} + + def setUp(self): + try: + os.remove(TESTFN) + except OSError: + try: + os.rmdir(TESTFN) + except OSError: + pass + tearDown = setUp + + def get_mode(self, fname=TESTFN): #, lstat=True): +# if lstat: +# st_mode = os.lstat(fname).st_mode +# else: + st_mode = os.stat(fname).st_mode + return st_mode + + def assertS_IS(self, name, mode): + # test format, lstrip is for S_IFIFO +# fmt = getattr(stat, "S_IF" + name.lstrip("F")) +# self.assertEqual(stat.S_IFMT(mode), fmt) + # test that just one function returns true + testname = "S_IS" + name + for funcname in self.format_funcs: + func = getattr(stat, funcname, None) + if func is None: + if funcname == testname: + raise ValueError(funcname) + continue + if funcname == testname: + self.assertTrue(func(mode)) + else: + self.assertFalse(func(mode)) + + @unittest.skip('grumpy') + def test_mode(self): + with open(TESTFN, 'w'): + pass + if os.name == 'posix': + os.chmod(TESTFN, 0o700) + st_mode = self.get_mode() + self.assertS_IS("REG", st_mode) + self.assertEqual(stat.S_IMODE(st_mode), + stat.S_IRWXU) + + os.chmod(TESTFN, 0o070) + st_mode = self.get_mode() + self.assertS_IS("REG", st_mode) + self.assertEqual(stat.S_IMODE(st_mode), + stat.S_IRWXG) + + os.chmod(TESTFN, 0o007) + st_mode = 
self.get_mode() + self.assertS_IS("REG", st_mode) + self.assertEqual(stat.S_IMODE(st_mode), + stat.S_IRWXO) + + os.chmod(TESTFN, 0o444) + st_mode = self.get_mode() + self.assertS_IS("REG", st_mode) + self.assertEqual(stat.S_IMODE(st_mode), 0o444) + else: + os.chmod(TESTFN, 0o700) + st_mode = self.get_mode() + self.assertS_IS("REG", st_mode) + self.assertEqual(stat.S_IFMT(st_mode), + stat.S_IFREG) + + def test_directory(self): + os.mkdir(TESTFN) + os.chmod(TESTFN, 0o700) + st_mode = self.get_mode() + self.assertS_IS("DIR", st_mode) + + @unittest.skip('grumpy') + @unittest.skipUnless(hasattr(os, 'symlink'), 'os.symlink not available') + def test_link(self): + try: + os.symlink(os.getcwd(), TESTFN) + except (OSError, NotImplementedError) as err: + raise unittest.SkipTest(str(err)) + else: + st_mode = self.get_mode() + self.assertS_IS("LNK", st_mode) + + @unittest.skip('grumpy') + @unittest.skipUnless(hasattr(os, 'mkfifo'), 'os.mkfifo not available') + def test_fifo(self): + os.mkfifo(TESTFN, 0o700) + st_mode = self.get_mode() + self.assertS_IS("FIFO", st_mode) + + @unittest.skip('grumpy') + @unittest.skipUnless(os.name == 'posix', 'requires Posix') + def test_devices(self): + if os.path.exists(os.devnull): + st_mode = self.get_mode(os.devnull, lstat=False) + self.assertS_IS("CHR", st_mode) + # Linux block devices, BSD has no block devices anymore + for blockdev in ("/dev/sda", "/dev/hda"): + if os.path.exists(blockdev): + st_mode = self.get_mode(blockdev, lstat=False) + self.assertS_IS("BLK", st_mode) + break + + @unittest.skip('grumpy') + def test_module_attributes(self): + for key, value in self.stat_struct.items(): + modvalue = getattr(stat, key) + self.assertEqual(value, modvalue, key) + for key, value in self.permission_bits.items(): + modvalue = getattr(stat, key) + self.assertEqual(value, modvalue, key) + for key in self.file_flags: + modvalue = getattr(stat, key) + self.assertIsInstance(modvalue, int) + for key in self.formats: + modvalue = getattr(stat, key) + self.assertIsInstance(modvalue, int) + for key in self.format_funcs: + func = getattr(stat, key) + self.assertTrue(callable(func)) + self.assertEqual(func(0), 0) + + +def test_main(): + run_unittest(TestFilemode) + +if __name__ == '__main__': + test_main() diff --git a/third_party/stdlib/test/test_support.py b/third_party/stdlib/test/test_support.py index 3bc18782..906ace0c 100644 --- a/third_party/stdlib/test/test_support.py +++ b/third_party/stdlib/test/test_support.py @@ -20,10 +20,10 @@ # import time # import struct # import sysconfig -# try: -# import thread -# except ImportError: -# thread = None +try: + import thread +except ImportError: + thread = None __all__ = [ "Error", "TestFailed", "have_unicode", "BasicTestRunner", "run_unittest", @@ -188,67 +188,67 @@ class TestFailed(Error): # except KeyError: # pass -# if sys.platform.startswith("win"): -# def _waitfor(func, pathname, waitall=False): -# # Perform the operation -# func(pathname) -# # Now setup the wait loop -# if waitall: -# dirname = pathname -# else: -# dirname, name = os.path.split(pathname) -# dirname = dirname or '.' -# # Check for `pathname` to be removed from the filesystem. -# # The exponential backoff of the timeout amounts to a total -# # of ~1 second after which the deletion is probably an error -# # anyway. -# # Testing on a i7@4.3GHz shows that usually only 1 iteration is -# # required when contention occurs. 
-# timeout = 0.001 -# while timeout < 1.0: -# # Note we are only testing for the existence of the file(s) in -# # the contents of the directory regardless of any security or -# # access rights. If we have made it this far, we have sufficient -# # permissions to do that much using Python's equivalent of the -# # Windows API FindFirstFile. -# # Other Windows APIs can fail or give incorrect results when -# # dealing with files that are pending deletion. -# L = os.listdir(dirname) -# if not (L if waitall else name in L): -# return -# # Increase the timeout and try again -# time.sleep(timeout) -# timeout *= 2 -# warnings.warn('tests may fail, delete still pending for ' + pathname, -# RuntimeWarning, stacklevel=4) - -# def _unlink(filename): -# _waitfor(os.unlink, filename) - -# def _rmdir(dirname): -# _waitfor(os.rmdir, dirname) - -# def _rmtree(path): -# def _rmtree_inner(path): -# for name in os.listdir(path): -# fullname = os.path.join(path, name) -# if os.path.isdir(fullname): -# _waitfor(_rmtree_inner, fullname, waitall=True) -# os.rmdir(fullname) -# else: -# os.unlink(fullname) -# _waitfor(_rmtree_inner, path, waitall=True) -# _waitfor(os.rmdir, path) -# else: -# _unlink = os.unlink -# _rmdir = os.rmdir -# _rmtree = shutil.rmtree - -# def unlink(filename): -# try: -# _unlink(filename) -# except OSError: -# pass +if sys.platform.startswith("win"): + def _waitfor(func, pathname, waitall=False): + # Perform the operation + func(pathname) + # Now setup the wait loop + if waitall: + dirname = pathname + else: + dirname, name = os.path.split(pathname) + dirname = dirname or '.' + # Check for `pathname` to be removed from the filesystem. + # The exponential backoff of the timeout amounts to a total + # of ~1 second after which the deletion is probably an error + # anyway. + # Testing on a i7@4.3GHz shows that usually only 1 iteration is + # required when contention occurs. + timeout = 0.001 + while timeout < 1.0: + # Note we are only testing for the existence of the file(s) in + # the contents of the directory regardless of any security or + # access rights. If we have made it this far, we have sufficient + # permissions to do that much using Python's equivalent of the + # Windows API FindFirstFile. + # Other Windows APIs can fail or give incorrect results when + # dealing with files that are pending deletion. 
+ L = os.listdir(dirname) + if not (L if waitall else name in L): + return + # Increase the timeout and try again + time.sleep(timeout) + timeout *= 2 + warnings.warn('tests may fail, delete still pending for ' + pathname, + RuntimeWarning, stacklevel=4) + + def _unlink(filename): + _waitfor(os.unlink, filename) + + def _rmdir(dirname): + _waitfor(os.rmdir, dirname) + + def _rmtree(path): + def _rmtree_inner(path): + for name in os.listdir(path): + fullname = os.path.join(path, name) + if os.path.isdir(fullname): + _waitfor(_rmtree_inner, fullname, waitall=True) + os.rmdir(fullname) + else: + os.unlink(fullname) + _waitfor(_rmtree_inner, path, waitall=True) + _waitfor(os.rmdir, path) +else: + _unlink = os.unlink + _rmdir = os.rmdir +# _rmtree = shutil.rmtree + +def unlink(filename): + try: + _unlink(filename) + except OSError: + pass # def rmdir(dirname): # try: @@ -569,14 +569,14 @@ class TestFailed(Error): except NameError: have_unicode = False -# requires_unicode = unittest.skipUnless(have_unicode, 'no unicode support') +requires_unicode = unittest.skipUnless(have_unicode, 'no unicode support') # def u(s): # return unicode(s, 'unicode-escape') -# # FS_NONASCII: non-ASCII Unicode character encodable by -# # sys.getfilesystemencoding(), or None if there is no such character. -# FS_NONASCII = None +# FS_NONASCII: non-ASCII Unicode character encodable by +# sys.getfilesystemencoding(), or None if there is no such character. +FS_NONASCII = None # if have_unicode: # for character in ( # # First try printable and common characters to have a readable filename. @@ -620,14 +620,14 @@ class TestFailed(Error): # FS_NONASCII = character # break -# # Filename used for testing -# if os.name == 'java': -# # Jython disallows @ in module names -# TESTFN = '$test' -# elif os.name == 'riscos': -# TESTFN = 'testfile' -# else: -# TESTFN = '@test' +# Filename used for testing +if os.name == 'java': + # Jython disallows @ in module names + TESTFN = '$test' +elif os.name == 'riscos': + TESTFN = 'testfile' +else: + TESTFN = '@test' # # Unicode name only used if TEST_FN_ENCODING exists for the platform. # if have_unicode: # # Assuming sys.getfilesystemencoding()!=sys.getdefaultencoding() @@ -666,10 +666,9 @@ class TestFailed(Error): # 'Unicode filename tests may not be effective' \ # % TESTFN_UNENCODABLE - -# # Disambiguate TESTFN for parallel testing, while letting it remain a valid -# # module name. -# TESTFN = "{}_{}_tmp".format(TESTFN, os.getpid()) +# Disambiguate TESTFN for parallel testing, while letting it remain a valid +# module name. +TESTFN = "%s_%s_tmp" % (TESTFN, os.getpid()) # # Save the initial cwd # SAVEDCWD = os.getcwd() @@ -1512,34 +1511,34 @@ def run_unittest(*classes): # print 'doctest (%s) ... %d tests with zero failures' % (module.__name__, t) # return f, t -# #======================================================================= -# # Threading support to prevent reporting refleaks when running regrtest.py -R - -# # NOTE: we use thread._count() rather than threading.enumerate() (or the -# # moral equivalent thereof) because a threading.Thread object is still alive -# # until its __bootstrap() method has returned, even after it has been -# # unregistered from the threading module. -# # thread._count(), on the other hand, only gets decremented *after* the -# # __bootstrap() method has returned, which gives us reliable reference counts -# # at the end of a test run. 
- -# def threading_setup(): -# if thread: -# return thread._count(), -# else: -# return 1, +#======================================================================= +# Threading support to prevent reporting refleaks when running regrtest.py -R + +# NOTE: we use thread._count() rather than threading.enumerate() (or the +# moral equivalent thereof) because a threading.Thread object is still alive +# until its __bootstrap() method has returned, even after it has been +# unregistered from the threading module. +# thread._count(), on the other hand, only gets decremented *after* the +# __bootstrap() method has returned, which gives us reliable reference counts +# at the end of a test run. + +def threading_setup(): + if thread: + return (thread._count(),) + else: + return (1,) -# def threading_cleanup(nb_threads): -# if not thread: -# return +def threading_cleanup(nb_threads): + if not thread: + return -# _MAX_COUNT = 10 -# for count in range(_MAX_COUNT): -# n = thread._count() -# if n == nb_threads: -# break -# time.sleep(0.1) -# # XXX print a warning in case of failure? + _MAX_COUNT = 10 + for count in range(_MAX_COUNT): + n = thread._count() + if n == nb_threads: + break + time.sleep(0.1) + # XXX print a warning in case of failure? # def reap_threads(func): # """Use this function when threads are being used. This will @@ -1558,25 +1557,25 @@ def run_unittest(*classes): # threading_cleanup(*key) # return decorator -# def reap_children(): -# """Use this function at the end of test_main() whenever sub-processes -# are started. This will help ensure that no extra children (zombies) -# stick around to hog resources and create problems when looking -# for refleaks. -# """ +def reap_children(): + """Use this function at the end of test_main() whenever sub-processes + are started. This will help ensure that no extra children (zombies) + stick around to hog resources and create problems when looking + for refleaks. + """ -# # Reap all our dead child processes so we don't leave zombies around. -# # These hog resources and might be causing some of the buildbots to die. -# if hasattr(os, 'waitpid'): -# any_process = -1 -# while True: -# try: -# # This will raise an exception on Windows. That's ok. -# pid, status = os.waitpid(any_process, os.WNOHANG) -# if pid == 0: -# break -# except: -# break + # Reap all our dead child processes so we don't leave zombies around. + # These hog resources and might be causing some of the buildbots to die. + if hasattr(os, 'waitpid'): + any_process = -1 + while True: + try: + # This will raise an exception on Windows. That's ok. + pid, status = os.waitpid(any_process, os.WNOHANG) + if pid == 0: + break + except: + break # @contextlib.contextmanager # def start_threads(threads, unlock=None): diff --git a/third_party/stdlib/test/test_threading.py b/third_party/stdlib/test/test_threading.py new file mode 100644 index 00000000..d8f2a5fa --- /dev/null +++ b/third_party/stdlib/test/test_threading.py @@ -0,0 +1,953 @@ +# Very rudimentary test of threading module + +import test.test_support +from test.test_support import verbose, cpython_only +#from test.script_helper import assert_python_ok + +import random +import re +import sys +#thread = test.test_support.import_module('thread') +import thread +#threading = test.test_support.import_module('threading') +import threading +import time +import unittest +import weakref +import os +#import subprocess +#try: +# import _testcapi +#except ImportError: +_testcapi = None + +from test import lock_tests + +# A trivial mutable counter. 
+class Counter(object): + def __init__(self): + self.value = 0 + def inc(self): + self.value += 1 + def dec(self): + self.value -= 1 + def get(self): + return self.value + +class TestThread(threading.Thread): + def __init__(self, name, testcase, sema, mutex, nrunning): + threading.Thread.__init__(self, name=name) + self.testcase = testcase + self.sema = sema + self.mutex = mutex + self.nrunning = nrunning + + def run(self): + delay = random.random() / 10000.0 + if verbose: + print 'task %s will run for %s usec' % ( + self.name, delay * 1e6) + + with self.sema: + with self.mutex: + self.nrunning.inc() + if verbose: + print self.nrunning.get(), 'tasks are running' + self.testcase.assertLessEqual(self.nrunning.get(), 3) + + time.sleep(delay) + if verbose: + print 'task', self.name, 'done' + + with self.mutex: + self.nrunning.dec() + self.testcase.assertGreaterEqual(self.nrunning.get(), 0) + if verbose: + print '%s is finished. %d tasks are running' % ( + self.name, self.nrunning.get()) + +class BaseTestCase(unittest.TestCase): + def setUp(self): + self._threads = test.test_support.threading_setup() + + def tearDown(self): + test.test_support.threading_cleanup(*self._threads) + test.test_support.reap_children() + + +class ThreadTests(BaseTestCase): + + # Create a bunch of threads, let each do some work, wait until all are + # done. + def test_various_ops(self): + # This takes about n/3 seconds to run (about n/3 clumps of tasks, + # times about 1 second per clump). + NUMTASKS = 10 + + # no more than 3 of the 10 can run at once + sema = threading.BoundedSemaphore(value=3) + mutex = threading.RLock() + numrunning = Counter() + + threads = [] + + for i in range(NUMTASKS): + t = TestThread("<thread %d>"%i, self, sema, mutex, numrunning) + threads.append(t) + self.assertIsNone(t.ident) + self.assertRegexpMatches(repr(t), r'^<TestThread\(.*, initial\)>$') + t.start() + + if verbose: + print 'waiting for all tasks to complete' + for t in threads: + t.join(NUMTASKS) + self.assertFalse(t.is_alive()) + self.assertNotEqual(t.ident, 0) + self.assertIsNotNone(t.ident) + self.assertRegexpMatches(repr(t), r'^<TestThread\(.*, \w+ -?\d+\)>$') + if verbose: + print 'all tasks done' + self.assertEqual(numrunning.get(), 0) + + def test_ident_of_no_threading_threads(self): + # The ident still must work for the main thread and dummy threads. + self.assertIsNotNone(threading.currentThread().ident) + def f(): + ident.append(threading.currentThread().ident) + done.set() + done = threading.Event() + ident = [] + thread.start_new_thread(f, ()) + done.wait() + self.assertIsNotNone(ident[0]) + # Kill the "immortal" _DummyThread + del threading._active[ident[0]] + + # run with a small(ish) thread stack size (256kB) + def test_various_ops_small_stack(self): + if verbose: + print 'with 256kB thread stack size...' + try: + threading.stack_size(262144) + except thread.error: + self.skipTest('platform does not support changing thread stack size') + self.test_various_ops() + threading.stack_size(0) + + # run with a large thread stack size (1MB) + def test_various_ops_large_stack(self): + if verbose: + print 'with 1MB thread stack size...' + try: + threading.stack_size(0x100000) + except thread.error: + self.skipTest('platform does not support changing thread stack size') + self.test_various_ops() + threading.stack_size(0) + + def test_foreign_thread(self): + # Check that a "foreign" thread can use the threading module. + def f(mutex): + # Calling current_thread() forces an entry for the foreign + # thread to get made in the threading._active map.
+ threading.current_thread() + mutex.release() + + mutex = threading.Lock() + mutex.acquire() + tid = thread.start_new_thread(f, (mutex,)) + # Wait for the thread to finish. + mutex.acquire() + self.assertIn(tid, threading._active) + self.assertIsInstance(threading._active[tid], threading._DummyThread) + del threading._active[tid] + + # PyThreadState_SetAsyncExc() is a CPython-only gimmick, not (currently) + # exposed at the Python level. This test relies on ctypes to get at it. + @unittest.skip('grumpy') + def test_PyThreadState_SetAsyncExc(self): + try: + #import ctypes + pass + except ImportError: + self.skipTest('requires ctypes') + + set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc + + class AsyncExc(Exception): + pass + + exception = ctypes.py_object(AsyncExc) + + # First check it works when setting the exception from the same thread. + tid = thread.get_ident() + + try: + result = set_async_exc(ctypes.c_long(tid), exception) + # The exception is async, so we might have to keep the VM busy until + # it notices. + while True: + pass + except AsyncExc: + pass + else: + # This code is unreachable but it reflects the intent. If we wanted + # to be smarter the above loop wouldn't be infinite. + self.fail("AsyncExc not raised") + try: + self.assertEqual(result, 1) # one thread state modified + except UnboundLocalError: + # The exception was raised too quickly for us to get the result. + pass + + # `worker_started` is set by the thread when it's inside a try/except + # block waiting to catch the asynchronously set AsyncExc exception. + # `worker_saw_exception` is set by the thread upon catching that + # exception. + worker_started = threading.Event() + worker_saw_exception = threading.Event() + + class Worker(threading.Thread): + def run(self): + self.id = thread.get_ident() + self.finished = False + + try: + while True: + worker_started.set() + time.sleep(0.1) + except AsyncExc: + self.finished = True + worker_saw_exception.set() + + t = Worker() + t.daemon = True # so if this fails, we don't hang Python at shutdown + t.start() + if verbose: + print " started worker thread" + + # Try a thread id that doesn't make sense. + if verbose: + print " trying nonsensical thread id" + result = set_async_exc(ctypes.c_long(-1), exception) + self.assertEqual(result, 0) # no thread states modified + + # Now raise an exception in the worker thread. + if verbose: + print " waiting for worker thread to get started" + ret = worker_started.wait() + self.assertTrue(ret) + if verbose: + print " verifying worker hasn't exited" + self.assertFalse(t.finished) + if verbose: + print " attempting to raise asynch exception in worker" + result = set_async_exc(ctypes.c_long(t.id), exception) + self.assertEqual(result, 1) # one thread state modified + if verbose: + print " waiting for worker to say it caught the exception" + worker_saw_exception.wait(timeout=10) + self.assertTrue(t.finished) + if verbose: + print " all OK -- joining worker" + if t.finished: + t.join() + # else the thread is still running, and we have no way to kill it + + def test_limbo_cleanup(self): + # Issue 7481: Failure to start thread should cleanup the limbo map. 
+ def fail_new_thread(*args): + raise thread.error() + _start_new_thread = threading._start_new_thread + threading._start_new_thread = fail_new_thread + try: + t = threading.Thread(target=lambda: None) + self.assertRaises(thread.error, t.start) + self.assertFalse( + t in threading._limbo, + "Failed to cleanup _limbo map on failure of Thread.start().") + finally: + threading._start_new_thread = _start_new_thread + + @unittest.skip('grumpy') + def test_finalize_runnning_thread(self): + # Issue 1402: the PyGILState_Ensure / _Release functions may be called + # very late on python exit: on deallocation of a running thread for + # example. + try: + #import ctypes + pass + except ImportError: + self.skipTest('requires ctypes') + + rc = subprocess.call([sys.executable, "-c", """if 1: + import ctypes, sys, time, thread + + # This lock is used as a simple event variable. + ready = thread.allocate_lock() + ready.acquire() + + # Module globals are cleared before __del__ is run + # So we save the functions in class dict + class C: + ensure = ctypes.pythonapi.PyGILState_Ensure + release = ctypes.pythonapi.PyGILState_Release + def __del__(self): + state = self.ensure() + self.release(state) + + def waitingThread(): + x = C() + ready.release() + time.sleep(100) + + thread.start_new_thread(waitingThread, ()) + ready.acquire() # Be sure the other thread is waiting. + sys.exit(42) + """]) + self.assertEqual(rc, 42) + + @unittest.skip('grumpy') + def test_finalize_with_trace(self): + # Issue1733757 + # Avoid a deadlock when sys.settrace steps into threading._shutdown + p = subprocess.Popen([sys.executable, "-c", """if 1: + import sys, threading + + # A deadlock-killer, to prevent the + # testsuite to hang forever + def killer(): + import os, time + time.sleep(2) + print 'program blocked; aborting' + os._exit(2) + t = threading.Thread(target=killer) + t.daemon = True + t.start() + + # This is the trace function + def func(frame, event, arg): + threading.current_thread() + return func + + sys.settrace(func) + """], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + self.addCleanup(p.stdout.close) + self.addCleanup(p.stderr.close) + stdout, stderr = p.communicate() + rc = p.returncode + self.assertFalse(rc == 2, "interpreted was blocked") + self.assertTrue(rc == 0, + "Unexpected error: " + repr(stderr)) + + @unittest.skip('grumpy') + def test_join_nondaemon_on_shutdown(self): + # Issue 1722344 + # Raising SystemExit skipped threading._shutdown + p = subprocess.Popen([sys.executable, "-c", """if 1: + import threading + from time import sleep + + def child(): + sleep(1) + # As a non-daemon thread we SHOULD wake up and nothing + # should be torn down yet + print "Woke up, sleep function is:", sleep + + threading.Thread(target=child).start() + raise SystemExit + """], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + self.addCleanup(p.stdout.close) + self.addCleanup(p.stderr.close) + stdout, stderr = p.communicate() + self.assertEqual(stdout.strip(), + "Woke up, sleep function is: ") + stderr = re.sub(r"^\[\d+ refs\]", "", stderr, re.MULTILINE).strip() + self.assertEqual(stderr, "") + + @unittest.skip('grumpy') + def test_enumerate_after_join(self): + # Try hard to trigger #1703448: a thread is still returned in + # threading.enumerate() after it has been join()ed. + enum = threading.enumerate + old_interval = sys.getcheckinterval() + try: + for i in xrange(1, 100): + # Try a couple times at each thread-switching interval + # to get more interleavings. 
+ sys.setcheckinterval(i // 5) + t = threading.Thread(target=lambda: None) + t.start() + t.join() + l = enum() + self.assertNotIn(t, l, + "#1703448 triggered after %d trials: %s" % (i, l)) + finally: + sys.setcheckinterval(old_interval) + + @unittest.skip('grumpy') + def test_no_refcycle_through_target(self): + class RunSelfFunction(object): + def __init__(self, should_raise): + # The links in this refcycle from Thread back to self + # should be cleaned up when the thread completes. + self.should_raise = should_raise + self.thread = threading.Thread(target=self._run, + args=(self,), + kwargs={'yet_another':self}) + self.thread.start() + + def _run(self, other_ref, yet_another): + if self.should_raise: + raise SystemExit + + cyclic_object = RunSelfFunction(should_raise=False) + weak_cyclic_object = weakref.ref(cyclic_object) + cyclic_object.thread.join() + del cyclic_object + self.assertEqual(None, weak_cyclic_object(), + msg=('%d references still around' % + sys.getrefcount(weak_cyclic_object()))) + + raising_cyclic_object = RunSelfFunction(should_raise=True) + weak_raising_cyclic_object = weakref.ref(raising_cyclic_object) + raising_cyclic_object.thread.join() + del raising_cyclic_object + self.assertEqual(None, weak_raising_cyclic_object(), + msg=('%d references still around' % + sys.getrefcount(weak_raising_cyclic_object()))) + + @unittest.skip('grumpy') + @unittest.skipUnless(hasattr(os, 'fork'), 'test needs fork()') + def test_dummy_thread_after_fork(self): + # Issue #14308: a dummy thread in the active list doesn't mess up + # the after-fork mechanism. + code = """if 1: + import thread, threading, os, time + + def background_thread(evt): + # Creates and registers the _DummyThread instance + threading.current_thread() + evt.set() + time.sleep(10) + + evt = threading.Event() + thread.start_new_thread(background_thread, (evt,)) + evt.wait() + assert threading.active_count() == 2, threading.active_count() + if os.fork() == 0: + assert threading.active_count() == 1, threading.active_count() + os._exit(0) + else: + os.wait() + """ + _, out, err = assert_python_ok("-c", code) + self.assertEqual(out, '') + self.assertEqual(err, '') + + @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()") + def test_is_alive_after_fork(self): + # Try hard to trigger #18418: is_alive() could sometimes be True on + # threads that vanished after a fork. + old_interval = sys.getcheckinterval() + + # Make the bug more likely to manifest. + sys.setcheckinterval(10) + + try: + for i in range(20): + t = threading.Thread(target=lambda: None) + t.start() + pid = os.fork() + if pid == 0: + os._exit(1 if t.is_alive() else 0) + else: + t.join() + pid, status = os.waitpid(pid, 0) + self.assertEqual(0, status) + finally: + sys.setcheckinterval(old_interval) + + def test_BoundedSemaphore_limit(self): + # BoundedSemaphore should raise ValueError if released too often. 
+ for limit in range(1, 10): + bs = threading.BoundedSemaphore(limit) + threads = [threading.Thread(target=bs.acquire) + for _ in range(limit)] + for t in threads: + t.start() + for t in threads: + t.join() + threads = [threading.Thread(target=bs.release) + for _ in range(limit)] + for t in threads: + t.start() + for t in threads: + t.join() + self.assertRaises(ValueError, bs.release) + +class ThreadJoinOnShutdown(BaseTestCase): + + # Between fork() and exec(), only async-safe functions are allowed (issues + # #12316 and #11870), and fork() from a worker thread is known to trigger + # problems with some operating systems (issue #3863): skip problematic tests + # on platforms known to behave badly. + platforms_to_skip = ('freebsd4', 'freebsd5', 'freebsd6', 'netbsd5', + 'os2emx') + + def _run_and_join(self, script): + script = """if 1: + import sys, os, time, threading + + # a thread, which waits for the main program to terminate + def joiningfunc(mainthread): + mainthread.join() + print 'end of thread' + \n""" + script + + p = subprocess.Popen([sys.executable, "-c", script], stdout=subprocess.PIPE) + rc = p.wait() + data = p.stdout.read().replace('\r', '') + p.stdout.close() + self.assertEqual(data, "end of main\nend of thread\n") + self.assertFalse(rc == 2, "interpreter was blocked") + self.assertTrue(rc == 0, "Unexpected error") + + @unittest.skip('grumpy') + def test_1_join_on_shutdown(self): + # The usual case: on exit, wait for a non-daemon thread + script = """if 1: + import os + t = threading.Thread(target=joiningfunc, + args=(threading.current_thread(),)) + t.start() + time.sleep(0.1) + print 'end of main' + """ + self._run_and_join(script) + + + @unittest.skip('grumpy') + @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()") + @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug") + def test_2_join_in_forked_process(self): + # Like the test above, but from a forked interpreter + script = """if 1: + childpid = os.fork() + if childpid != 0: + os.waitpid(childpid, 0) + sys.exit(0) + + t = threading.Thread(target=joiningfunc, + args=(threading.current_thread(),)) + t.start() + print 'end of main' + """ + self._run_and_join(script) + + @unittest.skip('grumpy') + @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()") + @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug") + def test_3_join_in_forked_from_thread(self): + # Like the test above, but fork() was called from a worker thread + # In the forked process, the main Thread object must be marked as stopped. 
+ script = """if 1: + main_thread = threading.current_thread() + def worker(): + childpid = os.fork() + if childpid != 0: + os.waitpid(childpid, 0) + sys.exit(0) + + t = threading.Thread(target=joiningfunc, + args=(main_thread,)) + print 'end of main' + t.start() + t.join() # Should not block: main_thread is already stopped + + w = threading.Thread(target=worker) + w.start() + """ + self._run_and_join(script) + + def assertScriptHasOutput(self, script, expected_output): + p = subprocess.Popen([sys.executable, "-c", script], + stdout=subprocess.PIPE) + rc = p.wait() + data = p.stdout.read().decode().replace('\r', '') + self.assertEqual(rc, 0, "Unexpected error") + self.assertEqual(data, expected_output) + + @unittest.skip('grumpy') + @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()") + @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug") + def test_4_joining_across_fork_in_worker_thread(self): + # There used to be a possible deadlock when forking from a child + # thread. See http://bugs.python.org/issue6643. + + # The script takes the following steps: + # - The main thread in the parent process starts a new thread and then + # tries to join it. + # - The join operation acquires the Lock inside the thread's _block + # Condition. (See threading.py:Thread.join().) + # - We stub out the acquire method on the condition to force it to wait + # until the child thread forks. (See LOCK ACQUIRED HERE) + # - The child thread forks. (See LOCK HELD and WORKER THREAD FORKS + # HERE) + # - The main thread of the parent process enters Condition.wait(), + # which releases the lock on the child thread. + # - The child process returns. Without the necessary fix, when the + # main thread of the child process (which used to be the child thread + # in the parent process) attempts to exit, it will try to acquire the + # lock in the Thread._block Condition object and hang, because the + # lock was held across the fork. + + script = """if 1: + import os, time, threading + + finish_join = False + start_fork = False + + def worker(): + # Wait until this thread's lock is acquired before forking to + # create the deadlock. + global finish_join + while not start_fork: + time.sleep(0.01) + # LOCK HELD: Main thread holds lock across this call. + childpid = os.fork() + finish_join = True + if childpid != 0: + # Parent process just waits for child. + os.waitpid(childpid, 0) + # Child process should just return. + + w = threading.Thread(target=worker) + + # Stub out the private condition variable's lock acquire method. + # This acquires the lock and then waits until the child has forked + # before returning, which will release the lock soon after. If + # someone else tries to fix this test case by acquiring this lock + # before forking instead of resetting it, the test case will + # deadlock when it shouldn't. 
+ condition = w._block + orig_acquire = condition.acquire + call_count_lock = threading.Lock() + call_count = 0 + def my_acquire(): + global call_count + global start_fork + orig_acquire() # LOCK ACQUIRED HERE + start_fork = True + if call_count == 0: + while not finish_join: + time.sleep(0.01) # WORKER THREAD FORKS HERE + with call_count_lock: + call_count += 1 + condition.acquire = my_acquire + + w.start() + w.join() + print('end of main') + """ + self.assertScriptHasOutput(script, "end of main\n") + + @unittest.skip('grumpy') + @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()") + @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug") + def test_5_clear_waiter_locks_to_avoid_crash(self): + # Check that a spawned thread that forks doesn't segfault on certain + # platforms, namely OS X. This used to happen if there was a waiter + # lock in the thread's condition variable's waiters list. Even though + # we know the lock will be held across the fork, it is not safe to + # release locks held across forks on all platforms, so releasing the + # waiter lock caused a segfault on OS X. Furthermore, since locks on + # OS X are (as of this writing) implemented with a mutex + condition + # variable instead of a semaphore, while we know that the Python-level + # lock will be acquired, we can't know if the internal mutex will be + # acquired at the time of the fork. + + script = """if True: + import os, time, threading + + start_fork = False + + def worker(): + # Wait until the main thread has attempted to join this thread + # before continuing. + while not start_fork: + time.sleep(0.01) + childpid = os.fork() + if childpid != 0: + # Parent process just waits for child. + (cpid, rc) = os.waitpid(childpid, 0) + assert cpid == childpid + assert rc == 0 + print('end of worker thread') + else: + # Child process should just return. + pass + + w = threading.Thread(target=worker) + + # Stub out the private condition variable's _release_save method. + # This releases the condition's lock and flips the global that + # causes the worker to fork. At this point, the problematic waiter + # lock has been acquired once by the waiter and has been put onto + # the waiters list. + condition = w._block + orig_release_save = condition._release_save + def my_release_save(): + global start_fork + orig_release_save() + # Waiter lock held here, condition lock released. + start_fork = True + condition._release_save = my_release_save + + w.start() + w.join() + print('end of main thread') + """ + output = "end of worker thread\nend of main thread\n" + self.assertScriptHasOutput(script, output) + + @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()") + @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug") + def test_reinit_tls_after_fork(self): + # Issue #13817: fork() would deadlock in a multithreaded program with + # the ad-hoc TLS implementation. + + def do_fork_and_wait(): + # just fork a child process and wait it + pid = os.fork() + if pid > 0: + os.waitpid(pid, 0) + else: + os._exit(0) + + # start a bunch of threads that will fork() child processes + threads = [] + for i in range(16): + t = threading.Thread(target=do_fork_and_wait) + threads.append(t) + t.start() + + for t in threads: + t.join() + + @cpython_only + @unittest.skipIf(_testcapi is None, "need _testcapi module") + def test_frame_tstate_tracing(self): + # Issue #14432: Crash when a generator is created in a C thread that is + # destroyed while the generator is still used. 
The issue was that a + # generator contains a frame, and the frame kept a reference to the + # Python state of the destroyed C thread. The crash occurs when a trace + # function is setup. + + def noop_trace(frame, event, arg): + # no operation + return noop_trace + + def generator(): + while 1: + yield "generator" + + def callback(): + if callback.gen is None: + callback.gen = generator() + return next(callback.gen) + callback.gen = None + + old_trace = sys.gettrace() + sys.settrace(noop_trace) + try: + # Install a trace function + threading.settrace(noop_trace) + + # Create a generator in a C thread which exits after the call + _testcapi.call_in_temporary_c_thread(callback) + + # Call the generator in a different Python thread, check that the + # generator didn't keep a reference to the destroyed thread state + for test in range(3): + # The trace function is still called here + callback() + finally: + sys.settrace(old_trace) + + +class ThreadingExceptionTests(BaseTestCase): + # A RuntimeError should be raised if Thread.start() is called + # multiple times. + def test_start_thread_again(self): + thread = threading.Thread() + thread.start() + self.assertRaises(RuntimeError, thread.start) + + def test_joining_current_thread(self): + current_thread = threading.current_thread() + self.assertRaises(RuntimeError, current_thread.join); + + def test_joining_inactive_thread(self): + thread = threading.Thread() + self.assertRaises(RuntimeError, thread.join) + + def test_daemonize_active_thread(self): + thread = threading.Thread() + thread.start() + self.assertRaises(RuntimeError, setattr, thread, "daemon", True) + + @unittest.skip('grumpy') + def test_print_exception(self): + script = r"""if 1: + import threading + import time + + running = False + def run(): + global running + running = True + while running: + time.sleep(0.01) + 1.0/0.0 + t = threading.Thread(target=run) + t.start() + while not running: + time.sleep(0.01) + running = False + t.join() + """ + rc, out, err = assert_python_ok("-c", script) + self.assertEqual(out, '') + self.assertIn("Exception in thread", err) + self.assertIn("Traceback (most recent call last):", err) + self.assertIn("ZeroDivisionError", err) + self.assertNotIn("Unhandled exception", err) + + @unittest.skip('grumpy') + def test_print_exception_stderr_is_none_1(self): + script = r"""if 1: + import sys + import threading + import time + + running = False + def run(): + global running + running = True + while running: + time.sleep(0.01) + 1.0/0.0 + t = threading.Thread(target=run) + t.start() + while not running: + time.sleep(0.01) + sys.stderr = None + running = False + t.join() + """ + rc, out, err = assert_python_ok("-c", script) + self.assertEqual(out, '') + self.assertIn("Exception in thread", err) + self.assertIn("Traceback (most recent call last):", err) + self.assertIn("ZeroDivisionError", err) + self.assertNotIn("Unhandled exception", err) + + @unittest.skip('grumpy') + def test_print_exception_stderr_is_none_2(self): + script = r"""if 1: + import sys + import threading + import time + + running = False + def run(): + global running + running = True + while running: + time.sleep(0.01) + 1.0/0.0 + sys.stderr = None + t = threading.Thread(target=run) + t.start() + while not running: + time.sleep(0.01) + running = False + t.join() + """ + rc, out, err = assert_python_ok("-c", script) + self.assertEqual(out, '') + self.assertNotIn("Unhandled exception", err) + + +class LockTests(lock_tests.LockTests): + locktype = staticmethod(threading.Lock) + +class 
RLockTests(lock_tests.RLockTests): + locktype = staticmethod(threading.RLock) + +class EventTests(lock_tests.EventTests): + eventtype = staticmethod(threading.Event) + +class ConditionAsRLockTests(lock_tests.RLockTests): + # Condition uses an RLock by default and exports its API. + locktype = staticmethod(threading.Condition) + +class ConditionTests(lock_tests.ConditionTests): + condtype = staticmethod(threading.Condition) + +class SemaphoreTests(lock_tests.SemaphoreTests): + semtype = staticmethod(threading.Semaphore) + +class BoundedSemaphoreTests(lock_tests.BoundedSemaphoreTests): + semtype = staticmethod(threading.BoundedSemaphore) + + @unittest.skip('grumpy') + @unittest.skipUnless(sys.platform == 'darwin', 'test macosx problem') + def test_recursion_limit(self): + # Issue 9670 + # test that excessive recursion within a non-main thread causes + # an exception rather than crashing the interpreter on platforms + # like Mac OS X or FreeBSD which have small default stack sizes + # for threads + script = """if True: + import threading + + def recurse(): + return recurse() + + def outer(): + try: + recurse() + except RuntimeError: + pass + + w = threading.Thread(target=outer) + w.start() + w.join() + print('end of main thread') + """ + expected_output = "end of main thread\n" + p = subprocess.Popen([sys.executable, "-c", script], + stdout=subprocess.PIPE) + stdout, stderr = p.communicate() + data = stdout.decode().replace('\r', '') + self.assertEqual(p.returncode, 0, "Unexpected error") + self.assertEqual(data, expected_output) + +def test_main(): + test.test_support.run_unittest(LockTests, RLockTests, EventTests, + ConditionAsRLockTests, ConditionTests, + SemaphoreTests, BoundedSemaphoreTests, + ThreadTests, + ThreadJoinOnShutdown, + ThreadingExceptionTests, + ) + +if __name__ == "__main__": + test_main() diff --git a/third_party/stdlib/test/test_uu.py b/third_party/stdlib/test/test_uu.py new file mode 100644 index 00000000..db998815 --- /dev/null +++ b/third_party/stdlib/test/test_uu.py @@ -0,0 +1,221 @@ +""" +Tests for uu module. 
+Nick Mathewson +""" + +import unittest +from test import test_support + +import sys, os, uu, cStringIO +import uu + +plaintext = "The smooth-scaled python crept over the sleeping dog\n" + +encodedtext = """\ +M5&AE('-M;V]T:\"US8V%L960@<'ET:&]N(&-R97!T(&]V97(@=&AE('-L965P +(:6YG(&1O9PH """ + +encodedtextwrapped = "begin %03o %s\n" + encodedtext.replace("%", "%%") + "\n \nend\n" + +class UUTest(unittest.TestCase): + + def test_encode(self): + inp = cStringIO.StringIO(plaintext) + out = cStringIO.StringIO() + uu.encode(inp, out, "t1") + self.assertEqual(out.getvalue(), encodedtextwrapped % (0666, "t1")) + inp = cStringIO.StringIO(plaintext) + out = cStringIO.StringIO() + uu.encode(inp, out, "t1", 0644) + self.assertEqual(out.getvalue(), encodedtextwrapped % (0644, "t1")) + + def test_decode(self): + inp = cStringIO.StringIO(encodedtextwrapped % (0666, "t1")) + out = cStringIO.StringIO() + uu.decode(inp, out) + self.assertEqual(out.getvalue(), plaintext) + inp = cStringIO.StringIO( + "UUencoded files may contain many lines,\n" + + "even some that have 'begin' in them.\n" + + encodedtextwrapped % (0666, "t1") + ) + out = cStringIO.StringIO() + uu.decode(inp, out) + self.assertEqual(out.getvalue(), plaintext) + + def test_truncatedinput(self): + inp = cStringIO.StringIO("begin 644 t1\n" + encodedtext) + out = cStringIO.StringIO() + try: + uu.decode(inp, out) + self.fail("No exception raised") + except uu.Error, e: + self.assertEqual(str(e), "Truncated input file") + + def test_missingbegin(self): + inp = cStringIO.StringIO("") + out = cStringIO.StringIO() + try: + uu.decode(inp, out) + self.fail("No exception raised") + except uu.Error, e: + self.assertEqual(str(e), "No valid begin line found in input file") + + def test_garbage_padding(self): + # Issue #22406 + encodedtext = ( + "begin 644 file\n" + # length 1; bits 001100 111111 111111 111111 + "\x21\x2C\x5F\x5F\x5F\n" + "\x20\n" + "end\n" + ) + plaintext = "\x33" # 00110011 + + inp = cStringIO.StringIO(encodedtext) + out = cStringIO.StringIO() + uu.decode(inp, out, quiet=True) + self.assertEqual(out.getvalue(), plaintext) + + #import codecs + #decoded = codecs.decode(encodedtext, "uu_codec") + #self.assertEqual(decoded, plaintext) + +class UUStdIOTest(unittest.TestCase): + + def setUp(self): + self.stdin = sys.stdin + self.stdout = sys.stdout + + def tearDown(self): + sys.stdin = self.stdin + sys.stdout = self.stdout + + def test_encode(self): + sys.stdin = cStringIO.StringIO(plaintext) + sys.stdout = cStringIO.StringIO() + uu.encode("-", "-", "t1", 0666) + self.assertEqual( + sys.stdout.getvalue(), + encodedtextwrapped % (0666, "t1") + ) + + def test_decode(self): + sys.stdin = cStringIO.StringIO(encodedtextwrapped % (0666, "t1")) + sys.stdout = cStringIO.StringIO() + uu.decode("-", "-") + self.assertEqual(sys.stdout.getvalue(), plaintext) + +class UUFileTest(unittest.TestCase): + + def _kill(self, f): + # close and remove file + try: + f.close() + except (SystemExit, KeyboardInterrupt): + raise + except: + pass + try: + os.unlink(f.name) + except (SystemExit, KeyboardInterrupt): + raise +# except: +# pass + + def setUp(self): + self.tmpin = test_support.TESTFN + "i" + self.tmpout = test_support.TESTFN + "o" + + def tearDown(self): + del self.tmpin + del self.tmpout + + def test_encode(self): + fin = fout = None + try: + test_support.unlink(self.tmpin) + fin = open(self.tmpin, 'wb') + fin.write(plaintext) + fin.close() + + fin = open(self.tmpin, 'rb') + fout = open(self.tmpout, 'w') + uu.encode(fin, fout, self.tmpin, mode=0644) + 
fin.close() + fout.close() + + fout = open(self.tmpout, 'r') + s = fout.read() + fout.close() + self.assertEqual(s, encodedtextwrapped % (0644, self.tmpin)) + + # in_file and out_file as filenames + uu.encode(self.tmpin, self.tmpout, self.tmpin, mode=0644) + fout = open(self.tmpout, 'r') + s = fout.read() + fout.close() + self.assertEqual(s, encodedtextwrapped % (0644, self.tmpin)) + + finally: + self._kill(fin) + self._kill(fout) + + def test_decode(self): + f = None + try: + test_support.unlink(self.tmpin) + f = open(self.tmpin, 'w') + f.write(encodedtextwrapped % (0644, self.tmpout)) + f.close() + + f = open(self.tmpin, 'r') + uu.decode(f) + f.close() + + f = open(self.tmpout, 'r') + s = f.read() + f.close() + self.assertEqual(s, plaintext) + # XXX is there an xp way to verify the mode? + finally: + self._kill(f) + + def test_decode_filename(self): + f = None + try: + test_support.unlink(self.tmpin) + f = open(self.tmpin, 'w') + f.write(encodedtextwrapped % (0644, self.tmpout)) + f.close() + + uu.decode(self.tmpin) + + f = open(self.tmpout, 'r') + s = f.read() + f.close() + self.assertEqual(s, plaintext) + finally: + self._kill(f) + + def test_decodetwice(self): + # Verify that decode() will refuse to overwrite an existing file + f = None + try: + f = cStringIO.StringIO(encodedtextwrapped % (0644, self.tmpout)) + + f = open(self.tmpin, 'r') + uu.decode(f) + f.close() + + f = open(self.tmpin, 'r') + self.assertRaises(uu.Error, uu.decode, f) + f.close() + finally: + self._kill(f) + +def test_main(): + test_support.run_unittest(UUTest, UUStdIOTest, UUFileTest) + +if __name__=="__main__": + test_main() diff --git a/third_party/stdlib/threading.py b/third_party/stdlib/threading.py new file mode 100644 index 00000000..1c34c848 --- /dev/null +++ b/third_party/stdlib/threading.py @@ -0,0 +1,1414 @@ +"""Thread module emulating a subset of Java's threading model.""" + +import sys as _sys + +try: + import thread +except ImportError: + del _sys.modules[__name__] + raise + +import warnings + +from collections import deque as _deque +from itertools import count as _count +from time import time as _time, sleep as _sleep +from traceback import format_exc as _format_exc + +# Note regarding PEP 8 compliant aliases +# This threading model was originally inspired by Java, and inherited +# the convention of camelCase function and method names from that +# language. While those names are not in any imminent danger of being +# deprecated, starting with Python 2.6, the module now provides a +# PEP 8 compliant alias for any such method name. +# Using the new PEP 8 compliant names also facilitates substitution +# with the multiprocessing module, which doesn't provide the old +# Java inspired names. + + +# Rename some stuff so "from threading import *" is safe +__all__ = ['activeCount', 'active_count', 'Condition', 'currentThread', + 'current_thread', 'enumerate', 'Event', + 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Thread', + 'Timer', 'setprofile', 'settrace', 'local', 'stack_size'] + +_start_new_thread = thread.start_new_thread +_allocate_lock = thread.allocate_lock +_get_ident = thread.get_ident +ThreadError = thread.error +del thread + + +# sys.exc_clear is used to work around the fact that except blocks +# don't fully clear the exception until 3.0. +warnings.filterwarnings('ignore', category=DeprecationWarning, + module='threading', message='sys.exc_clear') + +# Debug support (adapted from ihooks.py). +# All the major classes here derive from _Verbose. 
We force that to +# be a new-style class so that all the major classes here are new-style. +# This helps debugging (type(instance) is more revealing for instances +# of new-style classes). + +_VERBOSE = False + +if __debug__: + + class _Verbose(object): + + def __init__(self, verbose=None): + if verbose is None: + verbose = _VERBOSE + self.__verbose = verbose + + def _note(self, format, *args): + if self.__verbose: + format = format % args + # Issue #4188: calling current_thread() can incur an infinite + # recursion if it has to create a DummyThread on the fly. + ident = _get_ident() + try: + name = _active[ident].name + except KeyError: + name = "" % ident + format = "%s: %s\n" % (name, format) + _sys.stderr.write(format) + +else: + # Disable this when using "python -O" + class _Verbose(object): + def __init__(self, verbose=None): + pass + def _note(self, *args): + pass + +# Support for profile and trace hooks + +_profile_hook = None +_trace_hook = None + +def setprofile(func): + """Set a profile function for all threads started from the threading module. + + The func will be passed to sys.setprofile() for each thread, before its + run() method is called. + + """ + global _profile_hook + _profile_hook = func + +def settrace(func): + """Set a trace function for all threads started from the threading module. + + The func will be passed to sys.settrace() for each thread, before its run() + method is called. + + """ + global _trace_hook + _trace_hook = func + +# Synchronization classes + +Lock = _allocate_lock + +def RLock(*args, **kwargs): + """Factory function that returns a new reentrant lock. + + A reentrant lock must be released by the thread that acquired it. Once a + thread has acquired a reentrant lock, the same thread may acquire it again + without blocking; the thread must release it once for each time it has + acquired it. + + """ + return _RLock(*args, **kwargs) + +class _RLock(_Verbose): + """A reentrant lock must be released by the thread that acquired it. Once a + thread has acquired a reentrant lock, the same thread may acquire it + again without blocking; the thread must release it once for each time it + has acquired it. + """ + + def __init__(self, verbose=None): + _Verbose.__init__(self, verbose) + self.__block = _allocate_lock() + self.__owner = None + self.__count = 0 + + def __repr__(self): + owner = self.__owner + try: + owner = _active[owner].name + except KeyError: + pass + return "<%s owner=%r count=%d>" % ( + self.__class__.__name__, owner, self.__count) + + def acquire(self, blocking=1): + """Acquire a lock, blocking or non-blocking. + + When invoked without arguments: if this thread already owns the lock, + increment the recursion level by one, and return immediately. Otherwise, + if another thread owns the lock, block until the lock is unlocked. Once + the lock is unlocked (not owned by any thread), then grab ownership, set + the recursion level to one, and return. If more than one thread is + blocked waiting until the lock is unlocked, only one at a time will be + able to grab ownership of the lock. There is no return value in this + case. + + When invoked with the blocking argument set to true, do the same thing + as when called without arguments, and return true. + + When invoked with the blocking argument set to false, do not block. If a + call without an argument would block, return false immediately; + otherwise, do the same thing as when called without arguments, and + return true. 
+ + """ + me = _get_ident() + if self.__owner == me: + self.__count = self.__count + 1 + if __debug__: + self._note("%s.acquire(%s): recursive success", self, blocking) + return 1 + rc = self.__block.acquire(blocking) + if rc: + self.__owner = me + self.__count = 1 + if __debug__: + self._note("%s.acquire(%s): initial success", self, blocking) + else: + if __debug__: + self._note("%s.acquire(%s): failure", self, blocking) + return rc + + __enter__ = acquire + + def release(self): + """Release a lock, decrementing the recursion level. + + If after the decrement it is zero, reset the lock to unlocked (not owned + by any thread), and if any other threads are blocked waiting for the + lock to become unlocked, allow exactly one of them to proceed. If after + the decrement the recursion level is still nonzero, the lock remains + locked and owned by the calling thread. + + Only call this method when the calling thread owns the lock. A + RuntimeError is raised if this method is called when the lock is + unlocked. + + There is no return value. + + """ + if self.__owner != _get_ident(): + raise RuntimeError("cannot release un-acquired lock") + self.__count = count = self.__count - 1 + if not count: + self.__owner = None + self.__block.release() + if __debug__: + self._note("%s.release(): final release", self) + else: + if __debug__: + self._note("%s.release(): non-final release", self) + + def __exit__(self, t, v, tb): + self.release() + + # Internal methods used by condition variables + + def _acquire_restore(self, count_owner): + count, owner = count_owner + self.__block.acquire() + self.__count = count + self.__owner = owner + if __debug__: + self._note("%s._acquire_restore()", self) + + def _release_save(self): + if __debug__: + self._note("%s._release_save()", self) + count = self.__count + self.__count = 0 + owner = self.__owner + self.__owner = None + self.__block.release() + return (count, owner) + + def _is_owned(self): + return self.__owner == _get_ident() + + +def Condition(*args, **kwargs): + """Factory function that returns a new condition variable object. + + A condition variable allows one or more threads to wait until they are + notified by another thread. + + If the lock argument is given and not None, it must be a Lock or RLock + object, and it is used as the underlying lock. Otherwise, a new RLock object + is created and used as the underlying lock. + + """ + return _Condition(*args, **kwargs) + +class _Condition(_Verbose): + """Condition variables allow one or more threads to wait until they are + notified by another thread. + """ + + def __init__(self, lock=None, verbose=None): + _Verbose.__init__(self, verbose) + if lock is None: + lock = RLock() + self.__lock = lock + # Export the lock's acquire() and release() methods + self.acquire = lock.acquire + self.release = lock.release + # If the lock defines _release_save() and/or _acquire_restore(), + # these override the default implementations (which just call + # release() and acquire() on the lock). Ditto for _is_owned(). 
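+        # Illustrative note, not part of the upstream source: a plain Lock
+        # can use the defaults, since release()/acquire() fully unlock and
+        # relock it.  An RLock may be held recursively, so its
+        # _release_save() returns (count, owner) and really unlocks it,
+        # and _acquire_restore() puts that state back, e.g.
+        #
+        #     saved = lock._release_save()    # lock fully released
+        #     ...block on the waiter lock...
+        #     lock._acquire_restore(saved)    # recursion level restored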
+ try: + self._release_save = lock._release_save + except AttributeError: + pass + try: + self._acquire_restore = lock._acquire_restore + except AttributeError: + pass + try: + self._is_owned = lock._is_owned + except AttributeError: + pass + self.__waiters = [] + + def __enter__(self): + return self.__lock.__enter__() + + def __exit__(self, *args): + return self.__lock.__exit__(*args) + + def __repr__(self): + return "" % (self.__lock, len(self.__waiters)) + + def _release_save(self): + self.__lock.release() # No state to save + + def _acquire_restore(self, x): + self.__lock.acquire() # Ignore saved state + + def _is_owned(self): + # Return True if lock is owned by current_thread. + # This method is called only if __lock doesn't have _is_owned(). + if self.__lock.acquire(0): + self.__lock.release() + return False + else: + return True + + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. + + If the calling thread has not acquired the lock when this method is + called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notifyAll() call for the same condition + variable in another thread, or until the optional timeout occurs. Once + awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be a + floating point number specifying a timeout for the operation in seconds + (or fractions thereof). + + When the underlying lock is an RLock, it is not released using its + release() method, since this may not actually unlock the lock when it + was acquired multiple times recursively. Instead, an internal interface + of the RLock class is used, which really unlocks it even when it has + been recursively acquired several times. Another internal interface is + then used to restore the recursion level when the lock is reacquired. + + """ + if not self._is_owned(): + raise RuntimeError("cannot wait on un-acquired lock") + waiter = _allocate_lock() + waiter.acquire() + self.__waiters.append(waiter) + saved_state = self._release_save() + try: # restore state no matter what (e.g., KeyboardInterrupt) + if timeout is None: + waiter.acquire() + if __debug__: + self._note("%s.wait(): got it", self) + else: + # Balancing act: We can't afford a pure busy loop, so we + # have to sleep; but if we sleep the whole timeout time, + # we'll be unresponsive. The scheme here sleeps very + # little at first, longer as time goes on, but never longer + # than 20 times per second (or the timeout time remaining). + endtime = _time() + timeout + delay = 0.0005 # 500 us -> initial delay of 1 ms + while True: + gotit = waiter.acquire(0) + if gotit: + break + remaining = endtime - _time() + if remaining <= 0: + break + delay = min(delay * 2, remaining, .05) + _sleep(delay) + if not gotit: + if __debug__: + self._note("%s.wait(%s): timed out", self, timeout) + try: + self.__waiters.remove(waiter) + except ValueError: + pass + else: + if __debug__: + self._note("%s.wait(%s): got it", self, timeout) + finally: + self._acquire_restore(saved_state) + + def notify(self, n=1): + """Wake up one or more threads waiting on this condition, if any. + + If the calling thread has not acquired the lock when this method is + called, a RuntimeError is raised. + + This method wakes up at most n of the threads waiting for the condition + variable; it is a no-op if no threads are waiting. 
+ + """ + if not self._is_owned(): + raise RuntimeError("cannot notify on un-acquired lock") + __waiters = self.__waiters + waiters = __waiters[:n] + if not waiters: + if __debug__: + self._note("%s.notify(): no waiters", self) + return + self._note("%s.notify(): notifying %d waiter%s", self, n, + n!=1 and "s" or "") + for waiter in waiters: + waiter.release() + try: + __waiters.remove(waiter) + except ValueError: + pass + + def notifyAll(self): + """Wake up all threads waiting on this condition. + + If the calling thread has not acquired the lock when this method + is called, a RuntimeError is raised. + + """ + self.notify(len(self.__waiters)) + + notify_all = notifyAll + + +def Semaphore(*args, **kwargs): + """A factory function that returns a new semaphore. + + Semaphores manage a counter representing the number of release() calls minus + the number of acquire() calls, plus an initial value. The acquire() method + blocks if necessary until it can return without making the counter + negative. If not given, value defaults to 1. + + """ + return _Semaphore(*args, **kwargs) + +class _Semaphore(_Verbose): + """Semaphores manage a counter representing the number of release() calls + minus the number of acquire() calls, plus an initial value. The acquire() + method blocks if necessary until it can return without making the counter + negative. If not given, value defaults to 1. + + """ + + # After Tim Peters' semaphore class, but not quite the same (no maximum) + + def __init__(self, value=1, verbose=None): + if value < 0: + raise ValueError("semaphore initial value must be >= 0") + _Verbose.__init__(self, verbose) + self.__cond = Condition(Lock()) + self.__value = value + + def acquire(self, blocking=1): + """Acquire a semaphore, decrementing the internal counter by one. + + When invoked without arguments: if the internal counter is larger than + zero on entry, decrement it by one and return immediately. If it is zero + on entry, block, waiting until some other thread has called release() to + make it larger than zero. This is done with proper interlocking so that + if multiple acquire() calls are blocked, release() will wake exactly one + of them up. The implementation may pick one at random, so the order in + which blocked threads are awakened should not be relied on. There is no + return value in this case. + + When invoked with blocking set to true, do the same thing as when called + without arguments, and return true. + + When invoked with blocking set to false, do not block. If a call without + an argument would block, return false immediately; otherwise, do the + same thing as when called without arguments, and return true. + + """ + rc = False + with self.__cond: + while self.__value == 0: + if not blocking: + break + if __debug__: + self._note("%s.acquire(%s): blocked waiting, value=%s", + self, blocking, self.__value) + self.__cond.wait() + else: + self.__value = self.__value - 1 + if __debug__: + self._note("%s.acquire: success, value=%s", + self, self.__value) + rc = True + return rc + + __enter__ = acquire + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + + When the counter is zero on entry and another thread is waiting for it + to become larger than zero again, wake up that thread. 
+ + """ + with self.__cond: + self.__value = self.__value + 1 + if __debug__: + self._note("%s.release: success, value=%s", + self, self.__value) + self.__cond.notify() + + def __exit__(self, t, v, tb): + self.release() + + +def BoundedSemaphore(*args, **kwargs): + """A factory function that returns a new bounded semaphore. + + A bounded semaphore checks to make sure its current value doesn't exceed its + initial value. If it does, ValueError is raised. In most situations + semaphores are used to guard resources with limited capacity. + + If the semaphore is released too many times it's a sign of a bug. If not + given, value defaults to 1. + + Like regular semaphores, bounded semaphores manage a counter representing + the number of release() calls minus the number of acquire() calls, plus an + initial value. The acquire() method blocks if necessary until it can return + without making the counter negative. If not given, value defaults to 1. + + """ + return _BoundedSemaphore(*args, **kwargs) + +class _BoundedSemaphore(_Semaphore): + """A bounded semaphore checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. In most situations + semaphores are used to guard resources with limited capacity. + """ + + def __init__(self, value=1, verbose=None): + _Semaphore.__init__(self, value, verbose) + self._initial_value = value + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + + When the counter is zero on entry and another thread is waiting for it + to become larger than zero again, wake up that thread. + + If the number of releases exceeds the number of acquires, + raise a ValueError. + + """ + with self.__cond: + if self.__value >= self._initial_value: + raise ValueError("Semaphore released too many times") + self.__value += 1 + self.__cond.notify() + + +def Event(*args, **kwargs): + """A factory function that returns a new event. + + Events manage a flag that can be set to true with the set() method and reset + to false with the clear() method. The wait() method blocks until the flag is + true. + + """ + return _Event(*args, **kwargs) + +class _Event(_Verbose): + """A factory function that returns a new event object. An event manages a + flag that can be set to true with the set() method and reset to false + with the clear() method. The wait() method blocks until the flag is true. + + """ + + # After Tim Peters' event class (without is_posted()) + + def __init__(self, verbose=None): + _Verbose.__init__(self, verbose) + self.__cond = Condition(Lock()) + self.__flag = False + + def _reset_internal_locks(self): + # private! called by Thread._reset_internal_locks by _after_fork() + self.__cond.__init__(Lock()) + + def isSet(self): + 'Return true if and only if the internal flag is true.' + return self.__flag + + is_set = isSet + + def set(self): + """Set the internal flag to true. + + All threads waiting for the flag to become true are awakened. Threads + that call wait() once the flag is true will not block at all. + + """ + with self.__cond: + self.__flag = True + self.__cond.notify_all() + + def clear(self): + """Reset the internal flag to false. + + Subsequently, threads calling wait() will block until set() is called to + set the internal flag to true again. + + """ + with self.__cond: + self.__flag = False + + def wait(self, timeout=None): + """Block until the internal flag is true. + + If the internal flag is true on entry, return immediately. 
Otherwise, + block until another thread calls set() to set the flag to true, or until + the optional timeout occurs. + + When the timeout argument is present and not None, it should be a + floating point number specifying a timeout for the operation in seconds + (or fractions thereof). + + This method returns the internal flag on exit, so it will always return + True except if a timeout is given and the operation times out. + + """ + with self.__cond: + if not self.__flag: + self.__cond.wait(timeout) + return self.__flag + +# Helper to generate new thread names +_counter = _count().next +_counter() # Consume 0 so first non-main thread has id 1. +def _newname(template="Thread-%d"): + return template % _counter() + +# Active thread administration +_active_limbo_lock = _allocate_lock() +_active = {} # maps thread id to Thread object +_limbo = {} + + +# Main class for threads + +class Thread(_Verbose): + """A class that represents a thread of control. + + This class can be safely subclassed in a limited fashion. + + """ + __initialized = False + + def __init__(self, group=None, target=None, name=None, + args=(), kwargs=None, verbose=None): + """This constructor should always be called with keyword arguments. Arguments are: + + *group* should be None; reserved for future extension when a ThreadGroup + class is implemented. + + *target* is the callable object to be invoked by the run() + method. Defaults to None, meaning nothing is called. + + *name* is the thread name. By default, a unique name is constructed of + the form "Thread-N" where N is a small decimal number. + + *args* is the argument tuple for the target invocation. Defaults to (). + + *kwargs* is a dictionary of keyword arguments for the target + invocation. Defaults to {}. + + If a subclass overrides the constructor, it must make sure to invoke + the base class constructor (Thread.__init__()) before doing anything + else to the thread. + +""" + assert group is None, "group argument must be None for now" + _Verbose.__init__(self, verbose) + if kwargs is None: + kwargs = {} + self.__target = target + self.__name = str(name or _newname()) + self.__args = args + self.__kwargs = kwargs + self.__daemonic = self._set_daemon() + self.__ident = None + self.__started = Event() + self.__stopped = False + self.__block = Condition(Lock()) + self.__initialized = True + # sys.stderr is not stored in the class like + # sys.exc_info since it can be changed between instances + self.__stderr = _sys.stderr + + def _reset_internal_locks(self): + # private! Called by _after_fork() to reset our internal locks as + # they may be in an invalid state leading to a deadlock or crash. + if hasattr(self, '__block'): # DummyThread deletes self.__block + self.__block.__init__() + self.__started._reset_internal_locks() + + @property + def _block(self): + # used by a unittest + return self.__block + + def _set_daemon(self): + # Overridden in _MainThread and _DummyThread + return current_thread().daemon + + def __repr__(self): + assert self.__initialized, "Thread.__init__() was not called" + status = "initial" + if self.__started.is_set(): + status = "started" + if self.__stopped: + status = "stopped" + if self.__daemonic: + status += " daemon" + if self.__ident is not None: + status += " %s" % self.__ident + return "<%s(%s, %s)>" % (self.__class__.__name__, self.__name, status) + + def start(self): + """Start the thread's activity. + + It must be called at most once per thread object. 
It arranges for the + object's run() method to be invoked in a separate thread of control. + + This method will raise a RuntimeError if called more than once on the + same thread object. + + """ + if not self.__initialized: + raise RuntimeError("thread.__init__() not called") + if self.__started.is_set(): + raise RuntimeError("threads can only be started once") + if __debug__: + self._note("%s.start(): starting thread", self) + with _active_limbo_lock: + _limbo[self] = self + try: + _start_new_thread(self.__bootstrap, ()) + except Exception: + with _active_limbo_lock: + del _limbo[self] + raise + self.__started.wait() + + def run(self): + """Method representing the thread's activity. + + You may override this method in a subclass. The standard run() method + invokes the callable object passed to the object's constructor as the + target argument, if any, with sequential and keyword arguments taken + from the args and kwargs arguments, respectively. + + """ + try: + if self.__target: + self.__target(*self.__args, **self.__kwargs) + finally: + # Avoid a refcycle if the thread is running a function with + # an argument that has a member that points to the thread. + del self.__target, self.__args, self.__kwargs + + def __bootstrap(self): + # Wrapper around the real bootstrap code that ignores + # exceptions during interpreter cleanup. Those typically + # happen when a daemon thread wakes up at an unfortunate + # moment, finds the world around it destroyed, and raises some + # random exception *** while trying to report the exception in + # __bootstrap_inner() below ***. Those random exceptions + # don't help anybody, and they confuse users, so we suppress + # them. We suppress them only when it appears that the world + # indeed has already been destroyed, so that exceptions in + # __bootstrap_inner() during normal business hours are properly + # reported. Also, we only suppress them for daemonic threads; + # if a non-daemonic encounters this, something else is wrong. + try: + self.__bootstrap_inner() + except: + if self.__daemonic and _sys is None: + return + raise + + def _set_ident(self): + self.__ident = _get_ident() + + def __bootstrap_inner(self): + try: + self._set_ident() + self.__started.set() + with _active_limbo_lock: + _active[self.__ident] = self + del _limbo[self] + if __debug__: + self._note("%s.__bootstrap(): thread started", self) + + if _trace_hook: + self._note("%s.__bootstrap(): registering trace hook", self) + _sys.settrace(_trace_hook) + if _profile_hook: + self._note("%s.__bootstrap(): registering profile hook", self) + _sys.setprofile(_profile_hook) + + try: + self.run() + except SystemExit: + if __debug__: + self._note("%s.__bootstrap(): raised SystemExit", self) + except: + if __debug__: + self._note("%s.__bootstrap(): unhandled exception", self) + # If sys.stderr is no more (most likely from interpreter + # shutdown) use self.__stderr. Otherwise still use sys (as in + # _sys) in case sys.stderr was redefined since the creation of + # self. + if _sys and _sys.stderr is not None: + print>>_sys.stderr, ("Exception in thread %s:\n%s" % + (self.name, _format_exc())) + elif self.__stderr is not None: + # Do the best job possible w/o a huge amt. 
of code to + # approximate a traceback (code ideas from + # Lib/traceback.py) + exc_type, exc_value, exc_tb = _sys.exc_info() + try: + print>>self.__stderr, ( + "Exception in thread " + self.name + + " (most likely raised during interpreter shutdown):") + print>>self.__stderr, ( + "Traceback (most recent call last):") + while exc_tb: + print>>self.__stderr, ( + ' File "%s", line %s, in %s' % + (exc_tb.tb_frame.f_code.co_filename, + exc_tb.tb_lineno, + exc_tb.tb_frame.f_code.co_name)) + exc_tb = exc_tb.tb_next + print>>self.__stderr, ("%s: %s" % (exc_type, exc_value)) + # Make sure that exc_tb gets deleted since it is a memory + # hog; deleting everything else is just for thoroughness + finally: + del exc_type, exc_value, exc_tb + else: + if __debug__: + self._note("%s.__bootstrap(): normal return", self) + finally: + # Prevent a race in + # test_threading.test_no_refcycle_through_target when + # the exception keeps the target alive past when we + # assert that it's dead. + _sys.exc_clear() + finally: + with _active_limbo_lock: + self.__stop() + try: + # We don't call self.__delete() because it also + # grabs _active_limbo_lock. + del _active[_get_ident()] + except: + pass + + def __stop(self): + # DummyThreads delete self.__block, but they have no waiters to + # notify anyway (join() is forbidden on them). + if not hasattr(self, '__block'): + return + self.__block.acquire() + self.__stopped = True + self.__block.notify_all() + self.__block.release() + + def __delete(self): + "Remove current thread from the dict of currently running threads." + + # Notes about running with dummy_thread: + # + # Must take care to not raise an exception if dummy_thread is being + # used (and thus this module is being used as an instance of + # dummy_threading). dummy_thread.get_ident() always returns -1 since + # there is only one thread if dummy_thread is being used. Thus + # len(_active) is always <= 1 here, and any Thread instance created + # overwrites the (if any) thread currently registered in _active. + # + # An instance of _MainThread is always created by 'threading'. This + # gets overwritten the instant an instance of Thread is created; both + # threads return -1 from dummy_thread.get_ident() and thus have the + # same key in the dict. So when the _MainThread instance created by + # 'threading' tries to clean itself up when atexit calls this method + # it gets a KeyError if another Thread instance was created. + # + # This all means that KeyError from trying to delete something from + # _active if dummy_threading is being used is a red herring. But + # since it isn't if dummy_threading is *not* being used then don't + # hide the exception. + + try: + with _active_limbo_lock: + del _active[_get_ident()] + # There must not be any python code between the previous line + # and after the lock is released. Otherwise a tracing function + # could try to acquire the lock again in the same thread, (in + # current_thread()), and would block. + except KeyError: + if 'dummy_threading' not in _sys.modules: + raise + + def join(self, timeout=None): + """Wait until the thread terminates. + + This blocks the calling thread until the thread whose join() method is + called terminates -- either normally or through an unhandled exception + or until the optional timeout occurs. + + When the timeout argument is present and not None, it should be a + floating point number specifying a timeout for the operation in seconds + (or fractions thereof). 
As join() always returns None, you must call + isAlive() after join() to decide whether a timeout happened -- if the + thread is still alive, the join() call timed out. + + When the timeout argument is not present or None, the operation will + block until the thread terminates. + + A thread can be join()ed many times. + + join() raises a RuntimeError if an attempt is made to join the current + thread as that would cause a deadlock. It is also an error to join() a + thread before it has been started and attempts to do so raises the same + exception. + + """ + if not self.__initialized: + raise RuntimeError("Thread.__init__() not called") + if not self.__started.is_set(): + raise RuntimeError("cannot join thread before it is started") + if self is current_thread(): + raise RuntimeError("cannot join current thread") + + if __debug__: + if not self.__stopped: + self._note("%s.join(): waiting until thread stops", self) + self.__block.acquire() + try: + if timeout is None: + while not self.__stopped: + self.__block.wait() + if __debug__: + self._note("%s.join(): thread stopped", self) + else: + deadline = _time() + timeout + while not self.__stopped: + delay = deadline - _time() + if delay <= 0: + if __debug__: + self._note("%s.join(): timed out", self) + break + self.__block.wait(delay) + else: + if __debug__: + self._note("%s.join(): thread stopped", self) + finally: + self.__block.release() + + def _name_getter(self): + """A string used for identification purposes only. + + It has no semantics. Multiple threads may be given the same name. The + initial name is set by the constructor. + + """ + assert self.__initialized, "Thread.__init__() not called" + return self.__name + + def _name_setter(self, name): + assert self.__initialized, "Thread.__init__() not called" + self.__name = str(name) + + name = property(_name_getter, _name_setter) + + @property + def ident(self): + """Thread identifier of this thread or None if it has not been started. + + This is a nonzero integer. See the thread.get_ident() function. Thread + identifiers may be recycled when a thread exits and another thread is + created. The identifier is available even after the thread has exited. + + """ + assert self.__initialized, "Thread.__init__() not called" + return self.__ident + + def isAlive(self): + """Return whether the thread is alive. + + This method returns True just before the run() method starts until just + after the run() method terminates. The module function enumerate() + returns a list of all alive threads. + + """ + assert self.__initialized, "Thread.__init__() not called" + return self.__started.is_set() and not self.__stopped + + is_alive = isAlive + + def _daemon_getter(self): + """A boolean value indicating whether this thread is a daemon thread (True) or not (False). + + This must be set before start() is called, otherwise RuntimeError is + raised. Its initial value is inherited from the creating thread; the + main thread is not a daemon thread and therefore all threads created in + the main thread default to daemon = False. + + The entire Python program exits when no alive non-daemon threads are + left. 
+ + """ + assert self.__initialized, "Thread.__init__() not called" + return self.__daemonic + + def _daemon_setter(self, daemonic): + if not self.__initialized: + raise RuntimeError("Thread.__init__() not called") + if self.__started.is_set(): + raise RuntimeError("cannot set daemon status of active thread"); + self.__daemonic = daemonic + + daemon = property(_daemon_getter, _daemon_setter) + + def isDaemon(self): + return self.daemon + + def setDaemon(self, daemonic): + self.daemon = daemonic + + def getName(self): + return self.name + + def setName(self, name): + self.name = name + +# The timer class was contributed by Itamar Shtull-Trauring + +def Timer(*args, **kwargs): + """Factory function to create a Timer object. + + Timers call a function after a specified number of seconds: + + t = Timer(30.0, f, args=[], kwargs={}) + t.start() + t.cancel() # stop the timer's action if it's still waiting + + """ + return _Timer(*args, **kwargs) + +class _Timer(Thread): + """Call a function after a specified number of seconds: + + t = Timer(30.0, f, args=[], kwargs={}) + t.start() + t.cancel() # stop the timer's action if it's still waiting + + """ + + def __init__(self, interval, function, args=[], kwargs={}): + Thread.__init__(self) + self.interval = interval + self.function = function + self.args = args + self.kwargs = kwargs + self.finished = Event() + + def cancel(self): + """Stop the timer if it hasn't finished yet""" + self.finished.set() + + def run(self): + self.finished.wait(self.interval) + if not self.finished.is_set(): + self.function(*self.args, **self.kwargs) + self.finished.set() + +# Special thread class to represent the main thread +# This is garbage collected through an exit handler + +class _MainThread(Thread): + + def __init__(self): + Thread.__init__(self, name="MainThread") + self.__started.set() + self._set_ident() + with _active_limbo_lock: + _active[_get_ident()] = self + + def _set_daemon(self): + return False + + def _exitfunc(self): + self.__stop() + t = _pickSomeNonDaemonThread() + if t: + if __debug__: + self._note("%s: waiting for other threads", self) + while t: + t.join() + t = _pickSomeNonDaemonThread() + if __debug__: + self._note("%s: exiting", self) + self.__delete() + +def _pickSomeNonDaemonThread(): + for t in enumerate(): + if not t.daemon and t.is_alive(): + return t + return None + + +# Dummy thread class to represent threads not started here. +# These aren't garbage collected when they die, nor can they be waited for. +# If they invoke anything in threading.py that calls current_thread(), they +# leave an entry in the _active dict forever after. +# Their purpose is to return *something* from current_thread(). +# They are marked as daemon threads so we won't wait for them +# when we exit (conform previous semantics). + +class _DummyThread(Thread): + + def __init__(self): + Thread.__init__(self, name=_newname("Dummy-%d")) + + # Thread.__block consumes an OS-level locking primitive, which + # can never be used by a _DummyThread. Since a _DummyThread + # instance is immortal, that's bad, so release this resource. + del self.__block + + self.__started.set() + self._set_ident() + with _active_limbo_lock: + _active[_get_ident()] = self + + def _set_daemon(self): + return True + + def join(self, timeout=None): + assert False, "cannot join a dummy thread" + + +# Global API functions + +def currentThread(): + """Return the current Thread object, corresponding to the caller's thread of control. 
+ + If the caller's thread of control was not created through the threading + module, a dummy thread object with limited functionality is returned. + + """ + try: + return _active[_get_ident()] + except KeyError: + ##print "current_thread(): no current thread for", _get_ident() + return _DummyThread() + +current_thread = currentThread + +def activeCount(): + """Return the number of Thread objects currently alive. + + The returned count is equal to the length of the list returned by + enumerate(). + + """ + with _active_limbo_lock: + return len(_active) + len(_limbo) + +active_count = activeCount + +def _enumerate(): + # Same as enumerate(), but without the lock. Internal use only. + return _active.values() + _limbo.values() + +def enumerate(): + """Return a list of all Thread objects currently alive. + + The list includes daemonic threads, dummy thread objects created by + current_thread(), and the main thread. It excludes terminated threads and + threads that have not yet been started. + + """ + with _active_limbo_lock: + return _active.values() + _limbo.values() + +from thread import stack_size + +# Create the main thread object, +# and make it available for the interpreter +# (Py_Main) as threading._shutdown. + +_shutdown = _MainThread()._exitfunc + +# get thread-local implementation, either from the thread +# module, or from the python fallback + +# NOTE: Thread local classes follow: the Grumpy version of this file copies +# these from _threading_local.py to avoid circular dependency issues. + +class _localbase(object): + __slots__ = '_local__key', '_local__args', '_local__lock' + + def __new__(cls, *args, **kw): + self = object.__new__(cls) + key = '_local__key', 'thread.local.' + str(id(self)) + object.__setattr__(self, '_local__key', key) + object.__setattr__(self, '_local__args', (args, kw)) + object.__setattr__(self, '_local__lock', RLock()) + + if (args or kw) and (cls.__init__ is object.__init__): + raise TypeError("Initialization arguments are not supported") + + # We need to create the thread dict in anticipation of + # __init__ being called, to make sure we don't call it + # again ourselves. 
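+        # Illustrative sketch, not part of the upstream source: every
+        # thread gets its own attribute dict, stored under this instance's
+        # unique key in that thread's Thread.__dict__, e.g.
+        #
+        #     data = local()
+        #     data.x = 1      # visible only in the thread that set it
+        #     # another thread sees AttributeError until it assigns data.x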
+ dict = object.__getattribute__(self, '__dict__') + current_thread().__dict__[key] = dict + + return self + +def _patch(self): + key = object.__getattribute__(self, '_local__key') + d = current_thread().__dict__.get(key) + if d is None: + d = {} + current_thread().__dict__[key] = d + object.__setattr__(self, '__dict__', d) + + # we have a new instance dict, so call out __init__ if we have + # one + cls = type(self) + if cls.__init__ is not object.__init__: + args, kw = object.__getattribute__(self, '_local__args') + cls.__init__(self, *args, **kw) + else: + object.__setattr__(self, '__dict__', d) + +class local(_localbase): + + def __getattribute__(self, name): + lock = object.__getattribute__(self, '_local__lock') + lock.acquire() + try: + _patch(self) + return object.__getattribute__(self, name) + finally: + lock.release() + + def __setattr__(self, name, value): + if name == '__dict__': + raise AttributeError( + "%r object attribute '__dict__' is read-only" + % self.__class__.__name__) + lock = object.__getattribute__(self, '_local__lock') + lock.acquire() + try: + _patch(self) + return object.__setattr__(self, name, value) + finally: + lock.release() + + def __delattr__(self, name): + if name == '__dict__': + raise AttributeError( + "%r object attribute '__dict__' is read-only" + % self.__class__.__name__) + lock = object.__getattribute__(self, '_local__lock') + lock.acquire() + try: + _patch(self) + return object.__delattr__(self, name) + finally: + lock.release() + + def __del__(self): + key = object.__getattribute__(self, '_local__key') + + try: + # We use the non-locking API since we might already hold the lock + # (__del__ can be called at any point by the cyclic GC). + threads = _enumerate() + except: + # If enumerating the current threads fails, as it seems to do + # during shutdown, we'll skip cleanup under the assumption + # that there is nothing to clean up. + return + + for thread in threads: + try: + __dict__ = thread.__dict__ + except AttributeError: + # Thread is dying, rest in peace. + continue + + if key in __dict__: + try: + del __dict__[key] + except KeyError: + pass # didn't have anything in this thread + +# END _threading_local.py copy + +def _after_fork(): + # This function is called by Python/ceval.c:PyEval_ReInitThreads which + # is called from PyOS_AfterFork. Here we cleanup threading module state + # that should not exist after a fork. + + # Reset _active_limbo_lock, in case we forked while the lock was held + # by another (non-forked) thread. http://bugs.python.org/issue874900 + global _active_limbo_lock + _active_limbo_lock = _allocate_lock() + + # fork() only copied the current thread; clear references to others. + new_active = {} + current = current_thread() + with _active_limbo_lock: + for thread in _enumerate(): + # Any lock/condition variable may be currently locked or in an + # invalid state, so we reinitialize them. + if hasattr(thread, '_reset_internal_locks'): + thread._reset_internal_locks() + if thread is current: + # There is only one active thread. We reset the ident to + # its new value since it can have changed. + ident = _get_ident() + thread.__ident = ident + new_active[ident] = thread + else: + # All the others are already stopped. 
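+                # Illustrative note, not part of the upstream source:
+                # stopping a Thread object sets its stopped flag and
+                # notifies waiters on its internal condition, so a
+                # pre-fork t.join() issued in the child returns instead
+                # of hanging forever.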
+ thread.__stop() + + _limbo.clear() + _active.clear() + _active.update(new_active) + assert len(_active) == 1 + + +# Self-test code + +def _test(): + + class BoundedQueue(_Verbose): + + def __init__(self, limit): + _Verbose.__init__(self) + self.mon = RLock() + self.rc = Condition(self.mon) + self.wc = Condition(self.mon) + self.limit = limit + self.queue = _deque() + + def put(self, item): + self.mon.acquire() + while len(self.queue) >= self.limit: + self._note("put(%s): queue full", item) + self.wc.wait() + self.queue.append(item) + self._note("put(%s): appended, length now %d", + item, len(self.queue)) + self.rc.notify() + self.mon.release() + + def get(self): + self.mon.acquire() + while not self.queue: + self._note("get(): queue empty") + self.rc.wait() + item = self.queue.popleft() + self._note("get(): got %s, %d left", item, len(self.queue)) + self.wc.notify() + self.mon.release() + return item + + class ProducerThread(Thread): + + def __init__(self, queue, quota): + Thread.__init__(self, name="Producer") + self.queue = queue + self.quota = quota + + def run(self): + from random import random + counter = 0 + while counter < self.quota: + counter = counter + 1 + self.queue.put("%s.%d" % (self.name, counter)) + _sleep(random() * 0.00001) + + + class ConsumerThread(Thread): + + def __init__(self, queue, count): + Thread.__init__(self, name="Consumer") + self.queue = queue + self.count = count + + def run(self): + while self.count > 0: + item = self.queue.get() + print item + self.count = self.count - 1 + + NP = 3 + QL = 4 + NI = 5 + + Q = BoundedQueue(QL) + P = [] + for i in range(NP): + t = ProducerThread(Q, NI) + t.name = ("Producer-%d" % (i+1)) + P.append(t) + C = ConsumerThread(Q, NI*NP) + for t in P: + t.start() + _sleep(0.000001) + C.start() + for t in P: + t.join() + C.join() + +if __name__ == '__main__': + _test() diff --git a/third_party/stdlib/types.py b/third_party/stdlib/types.py index a2072d87..65e1fb62 100644 --- a/third_party/stdlib/types.py +++ b/third_party/stdlib/types.py @@ -72,7 +72,7 @@ def _m(self): pass del tb SliceType = slice -#EllipsisType = type(Ellipsis) +EllipsisType = type(Ellipsis) #DictProxyType = type(TypeType.__dict__) NotImplementedType = type(NotImplemented) diff --git a/third_party/stdlib/uu.py b/third_party/stdlib/uu.py new file mode 100644 index 00000000..f8fa4c47 --- /dev/null +++ b/third_party/stdlib/uu.py @@ -0,0 +1,196 @@ +#! /usr/bin/env python + +# Copyright 1994 by Lance Ellinghouse +# Cathedral City, California Republic, United States of America. +# All Rights Reserved +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Lance Ellinghouse +# not be used in advertising or publicity pertaining to distribution +# of the software without specific, written prior permission. +# LANCE ELLINGHOUSE DISCLAIMS ALL WARRANTIES WITH REGARD TO +# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS, IN NO EVENT SHALL LANCE ELLINGHOUSE CENTRUM BE LIABLE +# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+# +# Modified by Jack Jansen, CWI, July 1995: +# - Use binascii module to do the actual line-by-line conversion +# between ascii and binary. This results in a 1000-fold speedup. The C +# version is still 5 times faster, though. +# - Arguments more compliant with python standard + +"""Implementation of the UUencode and UUdecode functions. + +encode(in_file, out_file [,name, mode]) +decode(in_file [, out_file, mode]) +""" + +import binascii +import os +import sys + +__all__ = ["Error", "encode", "decode"] + +class Error(Exception): + pass + +def encode(in_file, out_file, name=None, mode=None): + """Uuencode file""" + # + # If in_file is a pathname open it and change defaults + # + opened_files = [] + try: + if in_file == '-': + in_file = sys.stdin + elif isinstance(in_file, basestring): + if name is None: + name = os.path.basename(in_file) + if mode is None: + try: + mode = os.stat(in_file).st_mode + except AttributeError: + pass + in_file = open(in_file, 'rb') + opened_files.append(in_file) + # + # Open out_file if it is a pathname + # + if out_file == '-': + out_file = sys.stdout + elif isinstance(out_file, basestring): + out_file = open(out_file, 'wb') + opened_files.append(out_file) + # + # Set defaults for name and mode + # + if name is None: + name = '-' + if mode is None: + mode = 0666 + # + # Write the data + # + out_file.write('begin %o %s\n' % ((mode&0777),name)) + data = in_file.read(45) + while len(data) > 0: + out_file.write(binascii.b2a_uu(data)) + data = in_file.read(45) + out_file.write(' \nend\n') + finally: + for f in opened_files: + f.close() + + +def decode(in_file, out_file=None, mode=None, quiet=0): + """Decode uuencoded file""" + # + # Open the input file, if needed. + # + opened_files = [] + if in_file == '-': + in_file = sys.stdin + elif isinstance(in_file, basestring): + in_file = open(in_file) + opened_files.append(in_file) + try: + # + # Read until a begin is encountered or we've exhausted the file + # + while True: + hdr = in_file.readline() + if not hdr: + raise Error('No valid begin line found in input file') + if not hdr.startswith('begin'): + continue + hdrfields = hdr.split(' ', 2) + if len(hdrfields) == 3 and hdrfields[0] == 'begin': + try: + int(hdrfields[1], 8) + break + except ValueError: + pass + if out_file is None: + out_file = hdrfields[2].rstrip() + if os.path.exists(out_file): + raise Error('Cannot overwrite existing file: %s' % out_file) + if mode is None: + mode = int(hdrfields[1], 8) + # + # Open the output file + # + if out_file == '-': + out_file = sys.stdout + elif isinstance(out_file, basestring): + fp = open(out_file, 'wb') + try: + os.path.chmod(out_file, mode) + except AttributeError: + pass + out_file = fp + opened_files.append(out_file) + # + # Main decoding loop + # + s = in_file.readline() + while s and s.strip() != 'end': + try: + data = binascii.a2b_uu(s) + except binascii.Error, v: + # Workaround for broken uuencoders by /Fredrik Lundh + nbytes = (((ord(s[0])-32) & 63) * 4 + 5) // 3 + data = binascii.a2b_uu(s[:nbytes]) + if not quiet: + sys.stderr.write("Warning: %s\n" % v) + out_file.write(data) + s = in_file.readline() + if not s: + raise Error('Truncated input file') + finally: + for f in opened_files: + f.close() + +def test(): + """uuencode/uudecode main program""" + + import optparse + parser = optparse.OptionParser(usage='usage: %prog [-d] [-t] [input [output]]') + parser.add_option('-d', '--decode', dest='decode', help='Decode (instead of encode)?', default=False, action='store_true') + parser.add_option('-t', '--text', 
dest='text', help='data is text, encoded format unix-compatible text?', default=False, action='store_true') + + (options, args) = parser.parse_args() + if len(args) > 2: + parser.error('incorrect number of arguments') + sys.exit(1) + + input = sys.stdin + output = sys.stdout + if len(args) > 0: + input = args[0] + if len(args) > 1: + output = args[1] + + if options.decode: + if options.text: + if isinstance(output, basestring): + output = open(output, 'w') + else: + print sys.argv[0], ': cannot do -t to stdout' + sys.exit(1) + decode(input, output) + else: + if options.text: + if isinstance(input, basestring): + input = open(input, 'r') + else: + print sys.argv[0], ': cannot do -t from stdin' + sys.exit(1) + encode(input, output) + +if __name__ == '__main__': + test() diff --git a/third_party/stdlib/weakref.py b/third_party/stdlib/weakref.py index 90d9196b..7112201c 100644 --- a/third_party/stdlib/weakref.py +++ b/third_party/stdlib/weakref.py @@ -20,7 +20,7 @@ # ProxyType, # ReferenceType) -from __go__.grumpy import WeakRefType as ReferenceType +from '__go__/grumpy' import WeakRefType as ReferenceType ref = ReferenceType import _weakrefset diff --git a/tools/genmake b/tools/genmake new file mode 100755 index 00000000..c13465a5 --- /dev/null +++ b/tools/genmake @@ -0,0 +1,87 @@ +#!/usr/bin/env python + +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Generate a Makefile for Python targets in a GOPATH directory.""" + +import argparse +import os +import subprocess +import sys + + +parser = argparse.ArgumentParser() +parser.add_argument('dir', help='GOPATH dir to scan for Python modules') +parser.add_argument('-all_target', default='all', + help='make target that will build all modules') + + +def _PrintRule(target, prereqs, rules): + print '{}: {}'.format(target, ' '.join(prereqs)) + if rules: + print '\t@mkdir -p $(@D)' + for rule in rules: + print '\t@{}'.format(rule) + print + + +def main(args): + try: + proc = subprocess.Popen('go env GOOS GOARCH', shell=True, + stdout=subprocess.PIPE) + except OSError as e: + print >> sys.stderr, str(e) + return 1 + out, _ = proc.communicate() + if proc.returncode: + print >> sys.stderr, 'go exited with status: {}'.format(proc.returncode) + return 1 + goos, goarch = out.split() + + if args.all_target: + print '{}:\n'.format(args.all_target) + + gopath = os.path.normpath(args.dir) + pkg_dir = os.path.join(gopath, 'pkg', '{}_{}'.format(goos, goarch)) + pydir = os.path.join(gopath, 'src', '__python__') + for dirpath, _, filenames in os.walk(pydir): + for filename in filenames: + if not filename.endswith('.py'): + continue + basename = os.path.relpath(dirpath, pydir) + if filename != '__init__.py': + basename = os.path.normpath( + os.path.join(basename, filename[:-3])) + modname = basename.replace(os.sep, '.') + ar_name = os.path.join(pkg_dir, '__python__', basename + '.a') + go_file = os.path.join(pydir, basename, 'module.go') + _PrintRule(go_file, + [os.path.join(dirpath, filename)], + ['grumpc -modname={} $< > $@'.format(modname)]) + recipe = (r"""pydeps -modname=%s $< | awk '{gsub(/\./, "/", $$0); """ + r"""print "%s: %s/__python__/" $$0 ".a"}' > $@""") + dep_file = os.path.join(pydir, basename, 'module.d') + _PrintRule(dep_file, [os.path.join(dirpath, filename)], + [recipe % (modname, ar_name, pkg_dir)]) + go_package = '__python__/' + basename.replace(os.sep, '/') + recipe = 'go tool compile -o $@ -p {} -complete -I {} -pack $<' + _PrintRule(ar_name, [go_file], [recipe.format(go_package, pkg_dir)]) + if args.all_target: + _PrintRule(args.all_target, [ar_name], []) + print '-include {}\n'.format(dep_file) + + +if __name__ == '__main__': + sys.exit(main(parser.parse_args())) diff --git a/tools/grumpc b/tools/grumpc index 91b36daa..53837738 100755 --- a/tools/grumpc +++ b/tools/grumpc @@ -17,34 +17,40 @@ """A Python -> Go transcompiler.""" +from __future__ import unicode_literals + import argparse -import ast +import os import sys import textwrap from grumpy.compiler import block +from grumpy.compiler import imputil from grumpy.compiler import stmt from grumpy.compiler import util +from grumpy import pythonparser parser = argparse.ArgumentParser() -parser.add_argument('filename', help='Python source filename') +parser.add_argument('script', help='Python source filename') parser.add_argument('-modname', default='__main__', help='Python module name') -parser.add_argument('-runtime', default='grumpy', - help='Grumpy runtime package name') -parser.add_argument('-libroot', default='grumpy/lib', - help='Path where Grumpy standard library packages live') def main(args): - for arg in ('filename', 'modname', 'runtime', 'libroot'): + for arg in ('script', 'modname'): if not getattr(args, arg, None): print >> sys.stderr, '{} arg must not be empty'.format(arg) return 1 - with open(args.filename) as py_file: + + gopath = os.getenv('GOPATH', None) + if not gopath: + print >> sys.stderr, 'GOPATH not set' + 
return 1 + + with open(args.script) as py_file: py_contents = py_file.read() try: - mod = ast.parse(py_contents) + mod = pythonparser.parse(py_contents) except SyntaxError as e: print >> sys.stderr, '{}: line {}: invalid syntax: {}'.format( e.filename, e.lineno, e.text) @@ -52,17 +58,18 @@ def main(args): # Do a pass for compiler directives from `from __future__ import *` statements try: - future_features = stmt.visit_future(mod) - except util.ParseError as e: + future_node, future_features = imputil.parse_future_features(mod) + except util.CompileError as e: print >> sys.stderr, str(e) return 2 + importer = imputil.Importer(gopath, args.modname, args.script, + future_features.absolute_import) full_package_name = args.modname.replace('.', '/') - mod_block = block.ModuleBlock(full_package_name, args.runtime, args.libroot, - args.filename, py_contents.split('\n'), - future_features) - mod_block.add_native_import('grumpy') - visitor = stmt.StatementVisitor(mod_block) + mod_block = block.ModuleBlock(importer, full_package_name, args.script, + py_contents, future_features) + + visitor = stmt.StatementVisitor(mod_block, future_node) # Indent so that the module body is aligned with the goto labels. with visitor.writer.indent_block(): try: @@ -71,39 +78,27 @@ def main(args): print >> sys.stderr, str(e) return 2 - imports = dict(mod_block.imports) - has_main = args.modname == '__main__' - if has_main: - imports['os'] = block.Package('os') - writer = util.Writer(sys.stdout) - package_name = args.modname.split('.')[-1] - if has_main: - package_name = 'main' - writer.write('package {}'.format(package_name)) - writer.write_import_block(imports) - - writer.write('func initModule(πF *πg.Frame, ' - '_ []*πg.Object) (*πg.Object, *πg.BaseException) {') - with writer.indent_block(): + tmpl = textwrap.dedent("""\ + package $package + import πg "grumpy" + var Code *πg.Code + func init() { + \tCode = πg.NewCode("", $script, nil, 0, func(πF *πg.Frame, _ []*πg.Object) (*πg.Object, *πg.BaseException) { + \t\tvar πR *πg.Object; _ = πR + \t\tvar πE *πg.BaseException; _ = πE""") + writer.write_tmpl(tmpl, package=args.modname.split('.')[-1], + script=util.go_str(args.script)) + with writer.indent_block(2): for s in sorted(mod_block.strings): writer.write('ß{} := πg.InternStr({})'.format(s, util.go_str(s))) writer.write_temp_decls(mod_block) - writer.write_block(mod_block, visitor.writer.out.getvalue()) - writer.write('}') - writer.write('var Code *πg.Code') - - if has_main: - writer.write_tmpl(textwrap.dedent("""\ - func main() { - \tCode = πg.NewCode("", $filename, nil, 0, initModule) - \tπ_os.Exit(πg.RunMain(Code)) - }"""), filename=util.go_str(args.filename)) - else: - writer.write_tmpl(textwrap.dedent("""\ - func init() { - \tCode = πg.NewCode("", $filename, nil, 0, initModule) - }"""), filename=util.go_str(args.filename)) + writer.write_block(mod_block, visitor.writer.getvalue()) + writer.write_tmpl(textwrap.dedent("""\ + \t\treturn nil, πE + \t}) + \tπg.RegisterModule($modname, Code) + }"""), modname=util.go_str(args.modname)) return 0 diff --git a/tools/grumprun b/tools/grumprun index 8efdc173..fd237713 100755 --- a/tools/grumprun +++ b/tools/grumprun @@ -22,46 +22,90 @@ Usage: $ grumprun -m # Run the named module. 
import argparse import os +import random +import shutil import string import subprocess import sys import tempfile +from grumpy.compiler import imputil + parser = argparse.ArgumentParser() -parser.add_argument('-m', '--module', help='Run the named module') +parser.add_argument('-m', '--modname', help='Run the named module') module_tmpl = string.Template("""\ package main import ( \t"os" \t"grumpy" -\t"grumpy/lib/$package" +\tmod "$package" +$imports ) func main() { -\tos.Exit(grumpy.RunMain($alias.Code)) +\tgrumpy.ImportModule(grumpy.NewRootFrame(), "traceback") +\tos.Exit(grumpy.RunMain(mod.Code)) } """) def main(args): + gopath = os.getenv('GOPATH', None) + if not gopath: + print >> sys.stderr, 'GOPATH not set' + return 1 + + modname = args.modname + workdir = tempfile.mkdtemp() try: - fd, path = tempfile.mkstemp(suffix='.go') - if args.module: - with os.fdopen(fd, 'w') as f: - package = args.module.replace('.', '/') - alias = package.split('/')[-1] - f.write(module_tmpl.substitute(package=package, alias=alias)) + if modname: + # Find the script associated with the given module. + for d in gopath.split(os.pathsep): + script = imputil.find_script( + os.path.join(d, 'src', '__python__'), modname) + if script: + break + else: + print >> sys.stderr, "can't find module", modname + return 1 else: + # Generate a dummy python script on the GOPATH. + modname = ''.join(random.choice(string.ascii_letters) for _ in range(16)) + py_dir = os.path.join(workdir, 'src', '__python__') + mod_dir = os.path.join(py_dir, modname) + os.makedirs(mod_dir) + script = os.path.join(py_dir, 'module.py') + with open(script, 'w') as f: + f.write(sys.stdin.read()) + gopath = gopath + os.pathsep + workdir + os.putenv('GOPATH', gopath) + # Compile the dummy script to Go using grumpc. + fd = os.open(os.path.join(mod_dir, 'module.go'), os.O_WRONLY | os.O_CREAT) try: - p = subprocess.Popen('grumpc /dev/stdin', stdout=fd, shell=True) + p = subprocess.Popen('grumpc ' + script, stdout=fd, shell=True) if p.wait(): return 1 finally: os.close(fd) - return subprocess.Popen('go run ' + path, shell=True).wait() + + names = imputil.calculate_transitive_deps(modname, script, gopath) + # Make sure traceback is available in all Python binaries. + names.add('traceback') + go_main = os.path.join(workdir, 'main.go') + package = _package_name(modname) + imports = ''.join('\t_ "' + _package_name(name) + '"\n' for name in names) + with open(go_main, 'w') as f: + f.write(module_tmpl.substitute(package=package, imports=imports)) + return subprocess.Popen('go run ' + go_main, shell=True).wait() finally: - os.remove(path) + shutil.rmtree(workdir) + + +def _package_name(modname): + if modname.startswith('__go__/'): + return '__python__/' + modname + return '__python__/' + modname.replace('.', '/') if __name__ == '__main__': diff --git a/tools/pkgc.go b/tools/pkgc.go new file mode 100644 index 00000000..4b08aac5 --- /dev/null +++ b/tools/pkgc.go @@ -0,0 +1,114 @@ +// pkgc is a tool for generating wrappers for Go packages imported by Grumpy +// programs. +// +// usage: pkgc PACKAGE +// +// Where PACKAGE is the full Go package name. Generated code is dumped to +// stdout. 
Packages generated in this way can be imported by Grumpy programs +// using string literal import syntax, e.g.: +// +// import "__go__/encoding/json" +// +// Or: +// +// from "__go__/time" import Duration + +package main + +import ( + "bytes" + "fmt" + "go/constant" + "go/importer" + "go/types" + "math" + "os" + "path" +) + +const packageTemplate = `package %[1]s +import ( + "grumpy" + "reflect" + mod %[2]q +) +func fun(f *grumpy.Frame, _ []*grumpy.Object) (*grumpy.Object, *grumpy.BaseException) { +%[3]s + return nil, nil +} +var Code = grumpy.NewCode("", %[2]q, nil, 0, fun) +func init() { + grumpy.RegisterModule("__go__/%[2]s", Code) +} +` + +const typeTemplate = ` if true { + var x mod.%[1]s + if o, raised := grumpy.WrapNative(f, reflect.ValueOf(x)); raised != nil { + return nil, raised + } else if raised = f.Globals().SetItemString(f, %[1]q, o.Type().ToObject()); raised != nil { + return nil, raised + } + } +` + +const varTemplate = ` if o, raised := grumpy.WrapNative(f, reflect.ValueOf(%[1]s)); raised != nil { + return nil, raised + } else if raised = f.Globals().SetItemString(f, %[2]q, o); raised != nil { + return nil, raised + } +` + +func getConst(name string, v constant.Value) string { + format := "%s" + switch v.Kind() { + case constant.Int: + if constant.Sign(v) >= 0 { + if i, exact := constant.Uint64Val(v); exact { + if i > math.MaxInt64 { + format = "uint64(%s)" + } + } else { + format = "float64(%s)" + } + } + case constant.Float: + format = "float64(%s)" + } + return fmt.Sprintf(format, name) +} + +func main() { + if len(os.Args) != 2 { + fmt.Fprint(os.Stderr, "usage: pkgc PACKAGE") + os.Exit(1) + } + pkgPath := os.Args[1] + pkg, err := importer.Default().Import(pkgPath) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to import: %q: %v\n", pkgPath, err) + os.Exit(2) + } + var buf bytes.Buffer + scope := pkg.Scope() + for _, name := range scope.Names() { + o := scope.Lookup(name) + if !o.Exported() { + continue + } + switch x := o.(type) { + case *types.TypeName: + if types.IsInterface(x.Type()) { + continue + } + buf.WriteString(fmt.Sprintf(typeTemplate, name)) + case *types.Const: + expr := getConst("mod." + name, x.Val()) + buf.WriteString(fmt.Sprintf(varTemplate, expr, name)) + default: + expr := "mod." + name + buf.WriteString(fmt.Sprintf(varTemplate, expr, name)) + } + } + fmt.Printf(packageTemplate, path.Base(pkgPath), pkgPath, buf.Bytes()) +} diff --git a/tools/pydeps b/tools/pydeps new file mode 100755 index 00000000..f0dd45f7 --- /dev/null +++ b/tools/pydeps @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Outputs names of modules imported by a script.""" + +import argparse +import os +import sys + +from grumpy.compiler import imputil +from grumpy.compiler import util + + +parser = argparse.ArgumentParser() +parser.add_argument('script', help='Python source filename') +parser.add_argument('-modname', default='__main__', help='Python module name') + + +def main(args): + gopath = os.getenv('GOPATH', None) + if not gopath: + print >> sys.stderr, 'GOPATH not set' + return 1 + + try: + imports = imputil.collect_imports(args.modname, args.script, gopath) + except SyntaxError as e: + print >> sys.stderr, '{}: line {}: invalid syntax: {}'.format( + e.filename, e.lineno, e.text) + return 2 + except util.CompileError as e: + print >> sys.stderr, str(e) + return 2 + + names = set([args.modname]) + for imp in imports: + if imp.is_native: + print imp.name + else: + parts = imp.name.split('.') + # Iterate over all packages and the leaf module. + for i in xrange(len(parts)): + name = '.'.join(parts[:i+1]) + if name not in names: + names.add(name) + print name + + +if __name__ == '__main__': + main(parser.parse_args())
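The _threading_local classes copied into threading.py above key each thread's private attribute dict on the instance's '_local__key', so attributes set through a `local` instance never leak between threads. A minimal sketch of the observable behaviour, assuming CPython 2.7's own threading module (the Grumpy copy is intended to match it):

    import threading

    data = threading.local()
    data.value = 'main'

    def worker():
        # This thread starts with an empty attribute dict for `data`;
        # assigning here does not touch the main thread's copy.
        data.value = 'worker'
        print 'worker sees', data.value

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    print 'main still sees', data.value  # prints 'main'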
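uu.encode() and uu.decode() in the new third_party/stdlib/uu.py accept open file objects, pathnames, or '-' for stdin/stdout. A small round-trip sketch, assuming that module is importable; the file names are illustrative only:

    import uu

    # Encode a binary file into uuencoded text; the output begins with a
    # 'begin 644 payload.bin' header derived from the mode and name arguments.
    uu.encode('payload.bin', 'payload.uu', name='payload.bin', mode=0644)

    # Decode it back; decode() raises uu.Error rather than overwrite an
    # existing output file.
    uu.decode('payload.uu', 'payload.copy')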
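pkgc registers each generated wrapper under the module name "__go__/<package path>", so Grumpy code reaches Go packages through the string-literal import forms shown in the pkgc.go header comment. A hypothetical Grumpy-only snippet (not valid CPython syntax; whether a given member is usable end to end depends on WrapNative support for its Go type):

    import "__go__/encoding/json"        # binds the wrapped package module
    from "__go__/time" import Duration   # binds one exported name

The weakref.py change above uses the same convention, importing WeakRefType from '__go__/grumpy'.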
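grumprun and pydeps both lean on grumpy.compiler.imputil to locate module sources under $GOPATH/src/__python__ and to expand transitive imports before anything is compiled. A hedged sketch of that flow, reusing the call signatures exactly as they appear in the tools above; the module name 'itertools' is only illustrative and the helpers' return values are assumed from how the tools consume them:

    import os
    from grumpy.compiler import imputil

    gopath = os.environ['GOPATH']
    pydir = os.path.join(gopath, 'src', '__python__')

    # Map a module name to its source file, as grumprun does for -m.
    script = imputil.find_script(pydir, 'itertools')

    # Expand the transitive import closure, as grumprun does before writing
    # main.go; each name corresponds to a __python__/<name>.a archive.
    names = imputil.calculate_transitive_deps('itertools', script, gopath)
    for name in sorted(names):
        print name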