diff --git a/src/MXNet.jl b/src/MXNet.jl index 9bae62d30..db47e7b7c 100644 --- a/src/MXNet.jl +++ b/src/MXNet.jl @@ -107,6 +107,9 @@ export AbstractDataProvider, # visualize.jl export to_graphviz +# NNVM +export Graph + ############################################################################### # includes ############################################################################### @@ -125,6 +128,8 @@ include("name.jl") include("symbolic-node.jl") include("executor.jl") +include("nnvm/graph.jl") + include("metric.jl") include("optimizer.jl") include("initializer.jl") diff --git a/src/base.jl b/src/base.jl index 271e35607..31d00e9cb 100644 --- a/src/base.jl +++ b/src/base.jl @@ -129,12 +129,13 @@ macro mx_define_handle_t(name, destructor) end end -@mx_define_handle_t(MX_NDArrayHandle, MXNDArrayFree) -@mx_define_handle_t(MX_OpHandle, nop) -@mx_define_handle_t(MX_SymbolHandle, MXSymbolFree) +@mx_define_handle_t(MX_NDArrayHandle, MXNDArrayFree) +@mx_define_handle_t(MX_OpHandle, nop) +@mx_define_handle_t(MX_SymbolHandle, MXSymbolFree) @mx_define_handle_t(MX_ExecutorHandle, MXExecutorFree) @mx_define_handle_t(MX_DataIterHandle, MXDataIterFree) -@mx_define_handle_t(MX_KVStoreHandle, MXKVStoreFree) +@mx_define_handle_t(MX_KVStoreHandle, MXKVStoreFree) +@mx_define_handle_t(NN_GraphHandle, NNGraphFree) ################################################################################ # MXNet Params diff --git a/src/nnvm/graph.jl b/src/nnvm/graph.jl new file mode 100644 index 000000000..367698df6 --- /dev/null +++ b/src/nnvm/graph.jl @@ -0,0 +1,86 @@ +struct Graph + handle::NN_GraphHandle +end + +function Graph(x::SymbolicNode) + h = Ref{MX_handle}(C_NULL) + @mxcall(:NNGraphCreate, (MX_handle, Ref{MX_handle}), x, h) + Graph(NN_GraphHandle(h[])) +end + +Base.unsafe_convert(::Type{MX_handle}, x::Graph) = + Base.unsafe_convert(MX_handle, x.handle) + +function getsymbol(x::Graph) + s = Ref{MX_handle}(C_NULL) + @mxcall(:NNGraphGetSymbol, (MX_handle, Ref{MX_handle}), x, s) + SymbolicNode(MX_SymbolHandle(s[])) +end + +function apply(x::Graph, pass::Symbol) + y = Ref{MX_handle}(C_NULL) + @mxcall(:NNGraphApplyPasses, (MX_handle, Cuint, char_pp, Ref{MX_handle}), + x, 1, [dump_mx_param(pass)], y) + Graph(NN_GraphHandle(y[])) +end + +function Base.getindex(x::Graph, k::Symbol) + json = Ref{char_p}(C_NULL) + success = Ref{Cint}(0) + @mxcall(:NNGraphGetJSONAttr, (MX_handle, char_p, Ref{char_p}, Ref{Cint}), + x, dump_mx_param(k), json, success) + success[] == 0 && throw(KeyError(k)) + typ, val = JSON.parse(unsafe_string(json[])) + + if typ == "str" + val + else + warn("unkown type $typ") + typ, val + end +end + +Base.setindex!(x::Graph, v, k::Symbol) = (x[:str, k] = v) + +function Base.setindex!(x::Graph, v, typ::Symbol, k::Symbol) + s = JSON.json([typ, v]) + @show s + @mxcall(:NNGraphSetJSONAttr, (MX_handle, Cstring, Cstring), + x, dump_mx_param(k), s) +end + +setshape!(x::Graph, shape::Tuple) = setshape!(x, [shape]) + +function setshape!(x::Graph, shapes::Vector) + x[:list_shape, :shape_inputs] = shapes + nothing +end + +setdtype!(x::Graph, t::Int) = setdtype!(x, [t]) + +function setdtype!(x::Graph, ts::Vector) + x[:list_int, :dtype_inputs] = ts + nothing +end + +function _set_node_attr!(x::Graph, k::Symbol, s::SymbolicNode) + @mxcall(:NNGraphSetNodeEntryListAttr_, (MX_handle, Cstring, MX_handle), + x, String(k), s) +end + +ir(x::SymbolicNode) = graphir(Graph(x)) +ir(x::Graph) = apply(x, :PrintGraphIR)[:graphir] + +function gradient(y::SymbolicNode, x::SymbolicNode) + g = Graph(y) + _set_node_attr!(g, :grad_ys, y) + _set_node_attr!(g, :grad_xs, x) + + # y could have multiple output + # ny = length(list_outputs(y)) + # ∇y = [ones_like(i) for i ∈ y] + ∇y = ones_like(y) + _set_node_attr!(g, :grad_ys_out_grad, ∇y) + + getsymbol(apply(g, :Gradient)) +end diff --git a/src/symbolic-node.jl b/src/symbolic-node.jl index bb3c97773..c11ac5eda 100644 --- a/src/symbolic-node.jl +++ b/src/symbolic-node.jl @@ -265,7 +265,7 @@ Base.show(io::IO, sym::SymbolicNode) = import Base: print function print(io::IO, sym::SymbolicNode) - out = Ref{mx.char_p}(C_NULL) + out = Ref{char_p}(C_NULL) @mx.mxcall(:MXSymbolPrint, (mx.MX_SymbolHandle, Ref{mx.char_p}), sym.handle, out) print(io, unsafe_string(out[])) end