bind: fix funcs with params with pointers

This CL introduces skBuiltin to differentiate between named types pointing at a
basic type and builtin types.
It also provides some limited heuristics to detect when a parameters is passed
by reference.

Updates #45

Change-Id: I0863303772514f819131d4bcf586358d0cc707db
This commit is contained in:
Sebastien Binet 2015-08-14 13:10:03 +02:00
parent 33f3ca06ba
commit 44026ff053
4 changed files with 81 additions and 39 deletions

View File

@ -264,7 +264,7 @@ func (g *cpyGen) gen() error {
for _, n := range g.pkg.syms.names() {
sym := g.pkg.syms.sym(n)
if !sym.isType() {
if !sym.isType() || sym.isBuiltin() {
continue
}
g.impl.Printf(
@ -280,7 +280,7 @@ func (g *cpyGen) gen() error {
for _, n := range g.pkg.syms.names() {
sym := g.pkg.syms.sym(n)
if !sym.isType() {
if !sym.isType() || sym.isBuiltin() {
continue
}
g.impl.Printf("Py_INCREF(&%sType);\n", sym.cpyname)

View File

@ -21,6 +21,10 @@ func (g *cpyGen) genType(sym *symbol) {
if sym.isBasic() && !sym.isNamed() {
return
}
_, isptr := sym.GoType().(*types.Pointer)
if isptr {
return
}
g.decl.Printf("\n/* --- decls for type %v --- */\n", sym.gofmt())
if sym.isBasic() {
@ -368,8 +372,7 @@ func (g *cpyGen) genTypeMembers(sym *symbol) {
func (g *cpyGen) genTypeMethods(sym *symbol) {
g.decl.Printf("\n/* methods for %s */\n", sym.gofmt())
if sym.isNamed() {
typ := sym.GoType().(*types.Named)
if typ, ok := sym.GoType().(*types.Named); ok {
for imeth := 0; imeth < typ.NumMethods(); imeth++ {
m := typ.Method(imeth)
if !m.Exported() {
@ -390,8 +393,7 @@ func (g *cpyGen) genTypeMethods(sym *symbol) {
g.impl.Printf("\n/* methods for %s */\n", sym.gofmt())
g.impl.Printf("static PyMethodDef %s_methods[] = {\n", sym.cpyname)
g.impl.Indent()
if sym.isNamed() {
typ := sym.GoType().(*types.Named)
if typ, ok := sym.GoType().(*types.Named); ok {
for imeth := 0; imeth < typ.NumMethods(); imeth++ {
m := typ.Method(imeth)
if !m.Exported() {

View File

@ -216,6 +216,7 @@ func cgo_func_%[1]s%[4]v%[5]v{
func (g *goGen) genFuncBody(f Func) {
sig := f.Signature()
tsig := f.GoType().Underlying().(*types.Signature)
results := sig.Results()
for i := range results {
if i > 0 {
@ -237,14 +238,43 @@ func (g *goGen) genFuncBody(f Func) {
}
head := arg.Name()
if arg.needWrap() {
head = fmt.Sprintf(
"*(*%s)(unsafe.Pointer(%s))",
types.TypeString(
arg.GoType(),
func(*types.Package) string { return g.pkg.Name() },
),
arg.Name(),
)
switch arg.GoType().(type) {
case *types.Pointer:
head = fmt.Sprintf(
"(%s)(unsafe.Pointer(%s))",
types.TypeString(
arg.GoType(),
func(*types.Package) string { return g.pkg.Name() },
),
arg.Name(),
)
default:
head = fmt.Sprintf(
"*(*%s)(unsafe.Pointer(%s))",
types.TypeString(
arg.GoType(),
func(*types.Package) string { return g.pkg.Name() },
),
arg.Name(),
)
}
} else {
targ := tsig.Params().At(i)
switch targ.Type().(type) {
case *types.Pointer:
head = "&" + head + fmt.Sprintf(
" /* kind=%v */",
arg.sym.kind,
)
if arg.sym.isBuiltin() {
head = "&" + arg.Name()
} else {
head = fmt.Sprintf("(%s)(unsafe.Pointer(&%s))",
arg.sym.gofmt(),
arg.Name(),
)
}
}
}
g.Printf("%s%s", head, tail)
}
@ -529,6 +559,11 @@ func (g *goGen) genType(sym *symbol) {
return
}
_, isptr := sym.GoType().(*types.Pointer)
if isptr {
return
}
g.Printf("\n// --- wrapping %s ---\n\n", sym.gofmt())
g.Printf("//export %[1]s\n", sym.cgoname)
g.Printf("// %[1]s wraps %[2]s\n", sym.cgoname, sym.gofmt())
@ -813,11 +848,10 @@ func (g *goGen) genTypeTPCall(sym *symbol) {
}
func (g *goGen) genTypeMethods(sym *symbol) {
if !sym.isNamed() {
typ, ok := sym.GoType().(*types.Named)
if !ok {
return
}
typ := sym.GoType().(*types.Named)
for imeth := 0; imeth < typ.NumMethods(); imeth++ {
m := typ.Method(imeth)
if !m.Exported() {

View File

@ -34,6 +34,7 @@ const (
skType
skArray
skBasic
skBuiltin
skInterface
skMap
skNamed
@ -52,6 +53,7 @@ var (
"type": skType,
"array": skArray,
"basic": skBasic,
"builtin": skBuiltin,
"interface": skInterface,
"map": skMap,
"named": skNamed,
@ -111,6 +113,10 @@ func (s symbol) isBasic() bool {
return (s.kind & skBasic) != 0
}
func (s symbol) isBuiltin() bool {
return (s.kind & skBuiltin) != 0
}
func (s symbol) isArray() bool {
return (s.kind & skArray) != 0
}
@ -393,7 +399,7 @@ func (sym *symtab) addType(obj types.Object, t types.Type) {
goname: n,
cgoname: "cgo_type_" + id,
cpyname: "cpy_type_" + id,
pyfmt: bsym.pyfmt,
pyfmt: "O",
pybuf: bsym.pybuf,
pysig: "object",
c2py: "cgopy_cnv_c2py_" + id,
@ -730,7 +736,7 @@ func init() {
gopkg: look("bool").Pkg(),
goobj: look("bool"),
gotyp: look("bool").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "bool",
cgoname: "GoUint8",
cpyname: "GoUint8",
@ -746,7 +752,7 @@ func init() {
gopkg: look("byte").Pkg(),
goobj: look("byte"),
gotyp: look("byte").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "byte",
cpyname: "uint8_t",
cgoname: "GoUint8",
@ -760,7 +766,7 @@ func init() {
gopkg: look("int").Pkg(),
goobj: look("int"),
gotyp: look("int").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "int",
cpyname: "int",
cgoname: "GoInt",
@ -776,7 +782,7 @@ func init() {
gopkg: look("int8").Pkg(),
goobj: look("int8"),
gotyp: look("int8").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "int8",
cpyname: "int8_t",
cgoname: "GoInt8",
@ -792,7 +798,7 @@ func init() {
gopkg: look("int16").Pkg(),
goobj: look("int16"),
gotyp: look("int16").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "int16",
cpyname: "int16_t",
cgoname: "GoInt16",
@ -808,7 +814,7 @@ func init() {
gopkg: look("int32").Pkg(),
goobj: look("int32"),
gotyp: look("int32").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "int32",
cpyname: "int32_t",
cgoname: "GoInt32",
@ -824,7 +830,7 @@ func init() {
gopkg: look("int64").Pkg(),
goobj: look("int64"),
gotyp: look("int64").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "int64",
cpyname: "int64_t",
cgoname: "GoInt64",
@ -840,7 +846,7 @@ func init() {
gopkg: look("uint").Pkg(),
goobj: look("uint"),
gotyp: look("uint").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "uint",
cpyname: "unsigned int",
cgoname: "GoUint",
@ -856,7 +862,7 @@ func init() {
gopkg: look("uint8").Pkg(),
goobj: look("uint8"),
gotyp: look("uint8").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "uint8",
cpyname: "uint8_t",
cgoname: "GoUint8",
@ -872,7 +878,7 @@ func init() {
gopkg: look("uint16").Pkg(),
goobj: look("uint16"),
gotyp: look("uint16").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "uint16",
cpyname: "uint16_t",
cgoname: "GoUint16",
@ -888,7 +894,7 @@ func init() {
gopkg: look("uint32").Pkg(),
goobj: look("uint32"),
gotyp: look("uint32").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "uint32",
cpyname: "uint32_t",
cgoname: "GoUint32",
@ -904,7 +910,7 @@ func init() {
gopkg: look("uint64").Pkg(),
goobj: look("uint64"),
gotyp: look("uint64").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "uint64",
cpyname: "uint64_t",
cgoname: "GoUint64",
@ -920,7 +926,7 @@ func init() {
gopkg: look("float32").Pkg(),
goobj: look("float32"),
gotyp: look("float32").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "float32",
cpyname: "float",
cgoname: "GoFloat32",
@ -936,7 +942,7 @@ func init() {
gopkg: look("float64").Pkg(),
goobj: look("float64"),
gotyp: look("float64").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "float64",
cpyname: "double",
cgoname: "GoFloat64",
@ -952,7 +958,7 @@ func init() {
gopkg: look("complex64").Pkg(),
goobj: look("complex64"),
gotyp: look("complex64").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "complex64",
cpyname: "float complex",
cgoname: "GoComplex64",
@ -968,7 +974,7 @@ func init() {
gopkg: look("complex128").Pkg(),
goobj: look("complex128"),
gotyp: look("complex128").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "complex128",
cpyname: "double complex",
cgoname: "GoComplex128",
@ -984,7 +990,7 @@ func init() {
gopkg: look("string").Pkg(),
goobj: look("string"),
gotyp: look("string").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "string",
cpyname: "GoString",
cgoname: "GoString",
@ -1000,7 +1006,7 @@ func init() {
gopkg: look("rune").Pkg(),
goobj: look("rune"),
gotyp: look("rune").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "rune",
cpyname: "GoRune",
cgoname: "GoRune",
@ -1016,7 +1022,7 @@ func init() {
gopkg: look("error").Pkg(),
goobj: look("error"),
gotyp: look("error").Type(),
kind: skType | skInterface,
kind: skType | skInterface | skBuiltin,
goname: "error",
cgoname: "GoInterface",
cpyname: "GoInterface",
@ -1033,7 +1039,7 @@ func init() {
gopkg: look("int").Pkg(),
goobj: look("int"),
gotyp: look("int").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "int",
cpyname: "int64_t",
cgoname: "GoInt",
@ -1048,7 +1054,7 @@ func init() {
gopkg: look("uint").Pkg(),
goobj: look("uint"),
gotyp: look("uint").Type(),
kind: skType | skBasic,
kind: skType | skBasic | skBuiltin,
goname: "uint",
cpyname: "uint64_t",
cgoname: "GoUint",