bind: fix funcs with params with pointers

This CL introduces skBuiltin to differentiate between named types pointing at a
basic type and builtin types.
It also provides some limited heuristics to detect when a parameters is passed
by reference.

Updates #45

Change-Id: I0863303772514f819131d4bcf586358d0cc707db
This commit is contained in:
Sebastien Binet 2015-08-14 13:10:03 +02:00
parent 33f3ca06ba
commit 44026ff053
4 changed files with 81 additions and 39 deletions

View File

@ -264,7 +264,7 @@ func (g *cpyGen) gen() error {
for _, n := range g.pkg.syms.names() { for _, n := range g.pkg.syms.names() {
sym := g.pkg.syms.sym(n) sym := g.pkg.syms.sym(n)
if !sym.isType() { if !sym.isType() || sym.isBuiltin() {
continue continue
} }
g.impl.Printf( g.impl.Printf(
@ -280,7 +280,7 @@ func (g *cpyGen) gen() error {
for _, n := range g.pkg.syms.names() { for _, n := range g.pkg.syms.names() {
sym := g.pkg.syms.sym(n) sym := g.pkg.syms.sym(n)
if !sym.isType() { if !sym.isType() || sym.isBuiltin() {
continue continue
} }
g.impl.Printf("Py_INCREF(&%sType);\n", sym.cpyname) g.impl.Printf("Py_INCREF(&%sType);\n", sym.cpyname)

View File

@ -21,6 +21,10 @@ func (g *cpyGen) genType(sym *symbol) {
if sym.isBasic() && !sym.isNamed() { if sym.isBasic() && !sym.isNamed() {
return return
} }
_, isptr := sym.GoType().(*types.Pointer)
if isptr {
return
}
g.decl.Printf("\n/* --- decls for type %v --- */\n", sym.gofmt()) g.decl.Printf("\n/* --- decls for type %v --- */\n", sym.gofmt())
if sym.isBasic() { if sym.isBasic() {
@ -368,8 +372,7 @@ func (g *cpyGen) genTypeMembers(sym *symbol) {
func (g *cpyGen) genTypeMethods(sym *symbol) { func (g *cpyGen) genTypeMethods(sym *symbol) {
g.decl.Printf("\n/* methods for %s */\n", sym.gofmt()) g.decl.Printf("\n/* methods for %s */\n", sym.gofmt())
if sym.isNamed() { if typ, ok := sym.GoType().(*types.Named); ok {
typ := sym.GoType().(*types.Named)
for imeth := 0; imeth < typ.NumMethods(); imeth++ { for imeth := 0; imeth < typ.NumMethods(); imeth++ {
m := typ.Method(imeth) m := typ.Method(imeth)
if !m.Exported() { if !m.Exported() {
@ -390,8 +393,7 @@ func (g *cpyGen) genTypeMethods(sym *symbol) {
g.impl.Printf("\n/* methods for %s */\n", sym.gofmt()) g.impl.Printf("\n/* methods for %s */\n", sym.gofmt())
g.impl.Printf("static PyMethodDef %s_methods[] = {\n", sym.cpyname) g.impl.Printf("static PyMethodDef %s_methods[] = {\n", sym.cpyname)
g.impl.Indent() g.impl.Indent()
if sym.isNamed() { if typ, ok := sym.GoType().(*types.Named); ok {
typ := sym.GoType().(*types.Named)
for imeth := 0; imeth < typ.NumMethods(); imeth++ { for imeth := 0; imeth < typ.NumMethods(); imeth++ {
m := typ.Method(imeth) m := typ.Method(imeth)
if !m.Exported() { if !m.Exported() {

View File

@ -216,6 +216,7 @@ func cgo_func_%[1]s%[4]v%[5]v{
func (g *goGen) genFuncBody(f Func) { func (g *goGen) genFuncBody(f Func) {
sig := f.Signature() sig := f.Signature()
tsig := f.GoType().Underlying().(*types.Signature)
results := sig.Results() results := sig.Results()
for i := range results { for i := range results {
if i > 0 { if i > 0 {
@ -237,14 +238,43 @@ func (g *goGen) genFuncBody(f Func) {
} }
head := arg.Name() head := arg.Name()
if arg.needWrap() { if arg.needWrap() {
head = fmt.Sprintf( switch arg.GoType().(type) {
"*(*%s)(unsafe.Pointer(%s))", case *types.Pointer:
types.TypeString( head = fmt.Sprintf(
arg.GoType(), "(%s)(unsafe.Pointer(%s))",
func(*types.Package) string { return g.pkg.Name() }, types.TypeString(
), arg.GoType(),
arg.Name(), func(*types.Package) string { return g.pkg.Name() },
) ),
arg.Name(),
)
default:
head = fmt.Sprintf(
"*(*%s)(unsafe.Pointer(%s))",
types.TypeString(
arg.GoType(),
func(*types.Package) string { return g.pkg.Name() },
),
arg.Name(),
)
}
} else {
targ := tsig.Params().At(i)
switch targ.Type().(type) {
case *types.Pointer:
head = "&" + head + fmt.Sprintf(
" /* kind=%v */",
arg.sym.kind,
)
if arg.sym.isBuiltin() {
head = "&" + arg.Name()
} else {
head = fmt.Sprintf("(%s)(unsafe.Pointer(&%s))",
arg.sym.gofmt(),
arg.Name(),
)
}
}
} }
g.Printf("%s%s", head, tail) g.Printf("%s%s", head, tail)
} }
@ -529,6 +559,11 @@ func (g *goGen) genType(sym *symbol) {
return return
} }
_, isptr := sym.GoType().(*types.Pointer)
if isptr {
return
}
g.Printf("\n// --- wrapping %s ---\n\n", sym.gofmt()) g.Printf("\n// --- wrapping %s ---\n\n", sym.gofmt())
g.Printf("//export %[1]s\n", sym.cgoname) g.Printf("//export %[1]s\n", sym.cgoname)
g.Printf("// %[1]s wraps %[2]s\n", sym.cgoname, sym.gofmt()) g.Printf("// %[1]s wraps %[2]s\n", sym.cgoname, sym.gofmt())
@ -813,11 +848,10 @@ func (g *goGen) genTypeTPCall(sym *symbol) {
} }
func (g *goGen) genTypeMethods(sym *symbol) { func (g *goGen) genTypeMethods(sym *symbol) {
if !sym.isNamed() { typ, ok := sym.GoType().(*types.Named)
if !ok {
return return
} }
typ := sym.GoType().(*types.Named)
for imeth := 0; imeth < typ.NumMethods(); imeth++ { for imeth := 0; imeth < typ.NumMethods(); imeth++ {
m := typ.Method(imeth) m := typ.Method(imeth)
if !m.Exported() { if !m.Exported() {

View File

@ -34,6 +34,7 @@ const (
skType skType
skArray skArray
skBasic skBasic
skBuiltin
skInterface skInterface
skMap skMap
skNamed skNamed
@ -52,6 +53,7 @@ var (
"type": skType, "type": skType,
"array": skArray, "array": skArray,
"basic": skBasic, "basic": skBasic,
"builtin": skBuiltin,
"interface": skInterface, "interface": skInterface,
"map": skMap, "map": skMap,
"named": skNamed, "named": skNamed,
@ -111,6 +113,10 @@ func (s symbol) isBasic() bool {
return (s.kind & skBasic) != 0 return (s.kind & skBasic) != 0
} }
func (s symbol) isBuiltin() bool {
return (s.kind & skBuiltin) != 0
}
func (s symbol) isArray() bool { func (s symbol) isArray() bool {
return (s.kind & skArray) != 0 return (s.kind & skArray) != 0
} }
@ -393,7 +399,7 @@ func (sym *symtab) addType(obj types.Object, t types.Type) {
goname: n, goname: n,
cgoname: "cgo_type_" + id, cgoname: "cgo_type_" + id,
cpyname: "cpy_type_" + id, cpyname: "cpy_type_" + id,
pyfmt: bsym.pyfmt, pyfmt: "O",
pybuf: bsym.pybuf, pybuf: bsym.pybuf,
pysig: "object", pysig: "object",
c2py: "cgopy_cnv_c2py_" + id, c2py: "cgopy_cnv_c2py_" + id,
@ -730,7 +736,7 @@ func init() {
gopkg: look("bool").Pkg(), gopkg: look("bool").Pkg(),
goobj: look("bool"), goobj: look("bool"),
gotyp: look("bool").Type(), gotyp: look("bool").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "bool", goname: "bool",
cgoname: "GoUint8", cgoname: "GoUint8",
cpyname: "GoUint8", cpyname: "GoUint8",
@ -746,7 +752,7 @@ func init() {
gopkg: look("byte").Pkg(), gopkg: look("byte").Pkg(),
goobj: look("byte"), goobj: look("byte"),
gotyp: look("byte").Type(), gotyp: look("byte").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "byte", goname: "byte",
cpyname: "uint8_t", cpyname: "uint8_t",
cgoname: "GoUint8", cgoname: "GoUint8",
@ -760,7 +766,7 @@ func init() {
gopkg: look("int").Pkg(), gopkg: look("int").Pkg(),
goobj: look("int"), goobj: look("int"),
gotyp: look("int").Type(), gotyp: look("int").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "int", goname: "int",
cpyname: "int", cpyname: "int",
cgoname: "GoInt", cgoname: "GoInt",
@ -776,7 +782,7 @@ func init() {
gopkg: look("int8").Pkg(), gopkg: look("int8").Pkg(),
goobj: look("int8"), goobj: look("int8"),
gotyp: look("int8").Type(), gotyp: look("int8").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "int8", goname: "int8",
cpyname: "int8_t", cpyname: "int8_t",
cgoname: "GoInt8", cgoname: "GoInt8",
@ -792,7 +798,7 @@ func init() {
gopkg: look("int16").Pkg(), gopkg: look("int16").Pkg(),
goobj: look("int16"), goobj: look("int16"),
gotyp: look("int16").Type(), gotyp: look("int16").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "int16", goname: "int16",
cpyname: "int16_t", cpyname: "int16_t",
cgoname: "GoInt16", cgoname: "GoInt16",
@ -808,7 +814,7 @@ func init() {
gopkg: look("int32").Pkg(), gopkg: look("int32").Pkg(),
goobj: look("int32"), goobj: look("int32"),
gotyp: look("int32").Type(), gotyp: look("int32").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "int32", goname: "int32",
cpyname: "int32_t", cpyname: "int32_t",
cgoname: "GoInt32", cgoname: "GoInt32",
@ -824,7 +830,7 @@ func init() {
gopkg: look("int64").Pkg(), gopkg: look("int64").Pkg(),
goobj: look("int64"), goobj: look("int64"),
gotyp: look("int64").Type(), gotyp: look("int64").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "int64", goname: "int64",
cpyname: "int64_t", cpyname: "int64_t",
cgoname: "GoInt64", cgoname: "GoInt64",
@ -840,7 +846,7 @@ func init() {
gopkg: look("uint").Pkg(), gopkg: look("uint").Pkg(),
goobj: look("uint"), goobj: look("uint"),
gotyp: look("uint").Type(), gotyp: look("uint").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "uint", goname: "uint",
cpyname: "unsigned int", cpyname: "unsigned int",
cgoname: "GoUint", cgoname: "GoUint",
@ -856,7 +862,7 @@ func init() {
gopkg: look("uint8").Pkg(), gopkg: look("uint8").Pkg(),
goobj: look("uint8"), goobj: look("uint8"),
gotyp: look("uint8").Type(), gotyp: look("uint8").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "uint8", goname: "uint8",
cpyname: "uint8_t", cpyname: "uint8_t",
cgoname: "GoUint8", cgoname: "GoUint8",
@ -872,7 +878,7 @@ func init() {
gopkg: look("uint16").Pkg(), gopkg: look("uint16").Pkg(),
goobj: look("uint16"), goobj: look("uint16"),
gotyp: look("uint16").Type(), gotyp: look("uint16").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "uint16", goname: "uint16",
cpyname: "uint16_t", cpyname: "uint16_t",
cgoname: "GoUint16", cgoname: "GoUint16",
@ -888,7 +894,7 @@ func init() {
gopkg: look("uint32").Pkg(), gopkg: look("uint32").Pkg(),
goobj: look("uint32"), goobj: look("uint32"),
gotyp: look("uint32").Type(), gotyp: look("uint32").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "uint32", goname: "uint32",
cpyname: "uint32_t", cpyname: "uint32_t",
cgoname: "GoUint32", cgoname: "GoUint32",
@ -904,7 +910,7 @@ func init() {
gopkg: look("uint64").Pkg(), gopkg: look("uint64").Pkg(),
goobj: look("uint64"), goobj: look("uint64"),
gotyp: look("uint64").Type(), gotyp: look("uint64").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "uint64", goname: "uint64",
cpyname: "uint64_t", cpyname: "uint64_t",
cgoname: "GoUint64", cgoname: "GoUint64",
@ -920,7 +926,7 @@ func init() {
gopkg: look("float32").Pkg(), gopkg: look("float32").Pkg(),
goobj: look("float32"), goobj: look("float32"),
gotyp: look("float32").Type(), gotyp: look("float32").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "float32", goname: "float32",
cpyname: "float", cpyname: "float",
cgoname: "GoFloat32", cgoname: "GoFloat32",
@ -936,7 +942,7 @@ func init() {
gopkg: look("float64").Pkg(), gopkg: look("float64").Pkg(),
goobj: look("float64"), goobj: look("float64"),
gotyp: look("float64").Type(), gotyp: look("float64").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "float64", goname: "float64",
cpyname: "double", cpyname: "double",
cgoname: "GoFloat64", cgoname: "GoFloat64",
@ -952,7 +958,7 @@ func init() {
gopkg: look("complex64").Pkg(), gopkg: look("complex64").Pkg(),
goobj: look("complex64"), goobj: look("complex64"),
gotyp: look("complex64").Type(), gotyp: look("complex64").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "complex64", goname: "complex64",
cpyname: "float complex", cpyname: "float complex",
cgoname: "GoComplex64", cgoname: "GoComplex64",
@ -968,7 +974,7 @@ func init() {
gopkg: look("complex128").Pkg(), gopkg: look("complex128").Pkg(),
goobj: look("complex128"), goobj: look("complex128"),
gotyp: look("complex128").Type(), gotyp: look("complex128").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "complex128", goname: "complex128",
cpyname: "double complex", cpyname: "double complex",
cgoname: "GoComplex128", cgoname: "GoComplex128",
@ -984,7 +990,7 @@ func init() {
gopkg: look("string").Pkg(), gopkg: look("string").Pkg(),
goobj: look("string"), goobj: look("string"),
gotyp: look("string").Type(), gotyp: look("string").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "string", goname: "string",
cpyname: "GoString", cpyname: "GoString",
cgoname: "GoString", cgoname: "GoString",
@ -1000,7 +1006,7 @@ func init() {
gopkg: look("rune").Pkg(), gopkg: look("rune").Pkg(),
goobj: look("rune"), goobj: look("rune"),
gotyp: look("rune").Type(), gotyp: look("rune").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "rune", goname: "rune",
cpyname: "GoRune", cpyname: "GoRune",
cgoname: "GoRune", cgoname: "GoRune",
@ -1016,7 +1022,7 @@ func init() {
gopkg: look("error").Pkg(), gopkg: look("error").Pkg(),
goobj: look("error"), goobj: look("error"),
gotyp: look("error").Type(), gotyp: look("error").Type(),
kind: skType | skInterface, kind: skType | skInterface | skBuiltin,
goname: "error", goname: "error",
cgoname: "GoInterface", cgoname: "GoInterface",
cpyname: "GoInterface", cpyname: "GoInterface",
@ -1033,7 +1039,7 @@ func init() {
gopkg: look("int").Pkg(), gopkg: look("int").Pkg(),
goobj: look("int"), goobj: look("int"),
gotyp: look("int").Type(), gotyp: look("int").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "int", goname: "int",
cpyname: "int64_t", cpyname: "int64_t",
cgoname: "GoInt", cgoname: "GoInt",
@ -1048,7 +1054,7 @@ func init() {
gopkg: look("uint").Pkg(), gopkg: look("uint").Pkg(),
goobj: look("uint"), goobj: look("uint"),
gotyp: look("uint").Type(), gotyp: look("uint").Type(),
kind: skType | skBasic, kind: skType | skBasic | skBuiltin,
goname: "uint", goname: "uint",
cpyname: "uint64_t", cpyname: "uint64_t",
cgoname: "GoUint", cgoname: "GoUint",