Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:15

0001 # coding: utf-8
0002 
0003 """
0004 Test script to create a simple graph for testing purposes at bin/data and save it using the
0005 SavedModel serialization format.
0006 
0007 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md
0008 """
0009 
0010 import os
0011 import sys
0012 
0013 import cmsml
0014 
0015 
# Obtain the TensorFlow handles via cmsml. Only the second entry (the v1
# compatibility layer) is used below, so it is the one bound to ``tf``.
_, tf1, tf_version = cmsml.tensorflow.import_tf()
tf = tf1

# the placeholder / session API used below requires graph mode
tf.disable_eager_execution()
0020 
# Resolve the output directory: the first command line argument wins,
# otherwise fall back to "bin/data" next to this script's parent directory.
if len(sys.argv) < 2:
    script_path = os.path.abspath(__file__)
    thisdir = os.path.dirname(script_path)
    parentdir = os.path.dirname(thisdir)
    datadir = os.path.join(parentdir, "bin", "data")
else:
    datadir = sys.argv[1]
0027 
# Build the graph: output = (input . W + b) * scale
# Inputs are named explicitly ("input", "scale") and so is the result
# ("output"), so they can be addressed by name once the model is exported.
x_ = tf.placeholder(dtype=tf.float32, shape=[None, 10], name="input")
scale_ = tf.placeholder(dtype=tf.float32, name="scale")

# weights and bias are initialized to ones (creation order is kept: W, then b)
weights = tf.Variable(tf.ones([10, 1]))
bias = tf.Variable(tf.ones([1]))

# affine transformation followed by the element-wise scaling
y = tf.multiply(tf.add(tf.matmul(x_, weights), bias), scale_, name="output")
0036 
# Restrict TensorFlow to the CPU by advertising zero GPU devices.
config = tf.ConfigProto(device_count={"GPU": 0})

# Use the session as a context manager so that it is closed, and the resources
# it owns are released, once the test run and the export are done.
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())

    # Sanity check: with all-ones weights and bias, input [0, ..., 9] and
    # scale 1.0, the printed value is 0 + 1 + ... + 9 + 1 = 46.0.
    result = sess.run(y, feed_dict={scale_: 1.0, x_: [range(10)]})
    print(result[0][0])

    # Write graph and variables to <datadir>/simplegraph with the serving tag.
    # NOTE(review): SavedModelBuilder presumably refuses to write into an
    # existing, non-empty directory -- confirm before re-running the script.
    export_dir = os.path.join(datadir, "simplegraph")
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
    builder.save()