The following 8 code examples, extracted from open-source Python projects, illustrate how to use tensorflow.RegisterGradient().
def tf_mod(x, y, name=None):
    """Differentiable mod op based on numpy.

    Args:
        x: first argument (dividend).
        y: second argument (divisor).
        name: optional name for the op's name scope.

    Returns:
        A tensor with ``np.mod(x, y)`` computed element-wise via ``tf.py_func``,
        reshaped to the shape of ``x``.
    """
    def np_mod(x, y):
        # Forward computation runs in numpy; result is forced to float32 so it
        # matches the [tf.float32] Tout declared in the py_func call below.
        return np.mod(x, y, dtype=np.float32)
    def modgrad(op, grad):
        x = op.inputs[0] # the first argument (normally you need those to calculate the gradient, like the gradient of x^2 is 2x. )
        y = op.inputs[1] # the second argument
        # Gradient is passed straight through w.r.t. x and zeroed w.r.t. y.
        # NOTE(review): analytically d(mod)/dy = -floor(x/y); returning 0 here
        # looks like a deliberate straight-through choice — confirm intent.
        return grad * 1, grad * 0 #the propagated gradient with respect to the first and second argument respectively
    def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
        # Need to generate a unique name to avoid duplicates:
        rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
        # Register the Python gradient under the fresh name, then remap the
        # "PyFunc" op type to it for the duration of this graph context.
        tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
        g = tf.get_default_graph()
        with g.gradient_override_map({"PyFunc": rnd_name}):
            return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
    with ops.name_scope(name, "mod", [x, y]) as name:
        # <-- here's the call to the gradient
        z = py_func(np_mod, [x, y], [tf.float32], name=name, grad=modgrad)
        # py_func returns a list of tensors; take the first and restore x's shape.
        return tf.reshape(z[0], tf.shape(x))
def py_func_grad(func, inp, Tout, stateful=True, name=None, grad=None):
    """Custom py_func with gradient support.

    Registers ``grad`` as the gradient function for this op under a freshly
    generated name, then builds the ``tf.py_func`` op with both the stateful
    and stateless PyFunc op types remapped to that gradient.
    """
    # A unique registration name prevents collisions when this wrapper is
    # called more than once in the same process.
    unique_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
    tf.RegisterGradient(unique_name)(grad)
    override = {
        "PyFunc": unique_name,
        "PyFuncStateless": unique_name,
    }
    graph = tf.get_default_graph()
    with graph.gradient_override_map(override):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    """Build a ``tf.py_func`` op whose gradient is the Python callable ``grad``.

    The gradient is registered under a randomly generated name so repeated
    calls do not collide, and the "PyFunc" op type is remapped to it while
    the op is constructed.
    """
    # Need to generate a unique name to avoid duplicates:
    unique_name = 'PyFuncGrad' + str(np.random.randint(0, 1000000))
    tf.RegisterGradient(unique_name)(grad)  # see _MySquareGrad for grad example
    graph = tf.get_default_graph()
    with graph.gradient_override_map({"PyFunc": unique_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
def __py_func(func, inp, Tout, stateful=False, name=None, grad=None):
    """Private ``tf.py_func`` wrapper that attaches a custom gradient.

    Note the default is ``stateful=False`` (the op is assumed deterministic),
    which is why "PyFuncStateless" is remapped alongside "PyFunc".
    """
    # Unique registration name avoids duplicate-gradient errors on reuse.
    grad_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
    tf.RegisterGradient(grad_name)(grad)
    overrides = {"PyFunc": grad_name, "PyFuncStateless": grad_name}
    graph = tf.get_default_graph()
    with graph.gradient_override_map(overrides):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    """Variant of ``tf.py_func`` that wires in a Python gradient function.

    ``grad`` is registered under a one-off name and the "PyFunc" op type is
    remapped to it via ``gradient_override_map`` while the op is created.
    """
    # Need to generate a unique name to avoid duplicates:
    registration = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
    tf.RegisterGradient(registration)(grad)  # see _MySquareGrad for grad example
    with tf.get_default_graph().gradient_override_map({"PyFunc": registration}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
def __init__(self, graph, session, y, x):
    """Constructs a GuidedBackprop SaliencyMask.

    Args:
        graph: the tf.Graph containing the model to explain.
        session: a tf.Session with the model's variables initialized.
        y: output tensor to differentiate (e.g. a logit).
        x: input tensor to differentiate with respect to.
    """
    super(GuidedBackprop, self).__init__(graph, session, y, x)
    self.x = x
    # Register the guided-ReLU gradient at most once per process:
    # tf.RegisterGradient raises if the same name is registered twice.
    if GuidedBackprop.GuidedReluRegistered is False:
        @tf.RegisterGradient("GuidedRelu")
        def _GuidedReluGrad(op, grad):
            # Pass the gradient only where BOTH the incoming gradient and the
            # ReLU's forward output are positive (guided backprop rule).
            gate_g = tf.cast(grad > 0, "float32")
            gate_y = tf.cast(op.outputs[0] > 0, "float32")
            return gate_y * gate_g * grad
    GuidedBackprop.GuidedReluRegistered = True
    # Checkpoint the live model so its weights can be restored into a second
    # graph that is built with the Relu gradient overridden.
    with graph.as_default():
        saver = tf.train.Saver()
        saver.save(session, '/tmp/guided_backprop_ckpt')
    graph_def = graph.as_graph_def()
    self.guided_graph = tf.Graph()
    with self.guided_graph.as_default():
        self.guided_sess = tf.Session(graph=self.guided_graph)
        # Re-import the original graph with every Relu's gradient swapped for
        # GuidedRelu, then restore the saved weights into the new session.
        with self.guided_graph.gradient_override_map({'Relu': 'GuidedRelu'}):
            tf.import_graph_def(graph_def, name='')
            saver.restore(self.guided_sess, '/tmp/guided_backprop_ckpt')
            # Look up the imported counterparts of y and x by tensor name.
            imported_y = self.guided_graph.get_tensor_by_name(y.name)
            imported_x = self.guided_graph.get_tensor_by_name(x.name)
            self.guided_grads_node = tf.gradients(imported_y, imported_x)[0]
def py_func(func, inp, Tout, name=None, grad=None):
    """Redefine tf.py_func to include gradients.

    Args:
        func: Python callable computing the forward pass.
        inp: list of input tensors.
        Tout: list of output dtypes.
        name: optional op name.
        grad: Python callable ``(op, grad) -> grads`` used as the gradient.

    Returns:
        The tensor(s) produced by ``tf.py_func`` with ``grad`` attached.
    """
    import uuid  # local import keeps this snippet self-contained
    # Generate a unique registration name. uuid4 replaces the original
    # reliance on tempfile._get_candidate_names(), a private API that is not
    # guaranteed to exist across Python versions, and cannot collide the way
    # a random-integer suffix can.
    _name = 'PyFuncGrad%s' % uuid.uuid4().hex
    tf.RegisterGradient(_name)(grad)
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": _name}):
        return tf.py_func(func, inp, Tout, name=name)