Я создал переменную tf.Variable
с tf.Variable
. Интересно, почему, если я вызываю tf.get_variable
с тем же именем, исключение не создается и создается новая переменная с добавленным именем?
import tensorflow as tf
class QuestionTest(tf.test.TestCase):
def test_version(self):
self.assertEqual(tf.__version__, '1.10.1')
def test_variable(self):
a = tf.Variable(0., trainable=False, name='test')
self.assertEqual(a.name, "test:0")
b = tf.get_variable('test', shape=(), trainable=False)
self.assertEqual(b.name, "test_1:0")
self.assertNotEqual(a, b, msg=''a' is not 'b'')
with self.assertRaises(ValueError) as ecm:
tf.get_variable('test', shape=(), trainable=False)
exception = ecm.exception
self.assertStartsWith(str(exception), "Variable test already exists, disallowed.")
Это связано с тем, что tf.Variable
- это метод низкого уровня, в котором хранится созданная переменная в коллекции GLOBALS (или LOCALS), в то время как tf.get_variable
сохраняет учетную запись переменной, которую она создала, сохраняя их в хранилище переменных.
Когда вы сначала вызываете tf.Variable
, созданная переменная не добавляется в хранилище переменных, позволяя думать, что никакой переменной с именем "test"
не было создано.
Итак, когда вы позже tf.get_variable("test")
он будет смотреть на хранилище переменных, см., Что в нем нет переменной с именем "test"
.
Таким образом, он вызовет tf.Variable
, который создаст переменную с добавленным именем "test_1"
хранящимся в хранилище переменных под ключом "test"
.
import tensorflow as tf
class AnswerTest(tf.test.TestCase):
def test_version(self):
self.assertEqual(tf.__version__, '1.10.1')
def test_variable_answer(self):
"""Using the default variable scope"""
# Let first check the __variable_store and the GLOBALS collections.
self.assertListEqual(tf.get_collection(("__variable_store",)), [],
"No variable store.")
self.assertListEqual(tf.global_variables(), [],
"No global variables")
a = tf.Variable(0., trainable=False, name='test')
self.assertEqual(a.name, "test:0")
self.assertListEqual(tf.get_collection(("__variable_store",)), [],
"No variable store.")
self.assertListEqual(tf.global_variables(), [a],
"but 'a' is in global variables.")
b = tf.get_variable('test', shape=(), trainable=False)
self.assertNotEqual(a, b, msg=''a' is not 'b'')
self.assertEqual(b.name, "test_1:0", msg="'b' name is not 'test'.")
self.assertTrue(len(tf.get_collection(("__variable_store",))) > 0,
"There is now a variable store.")
var_store = tf.get_collection(("__variable_store",))[0]
self.assertDictEqual(var_store._vars, {"test": b},
"and variable 'b' is in it.")
self.assertListEqual(tf.global_variables(), [a, b],
"while 'a' and 'b' are in global variables.")
with self.assertRaises(ValueError) as exception_context_manager:
tf.get_variable('test', shape=(), trainable=False)
exception = exception_context_manager.exception
self.assertStartsWith(str(exception),
"Variable test already exists, disallowed.")
То же самое верно при использовании явной переменной.
def test_variable_answer_with_variable_scope(self):
"""Using now a variable scope"""
self.assertListEqual(tf.get_collection(("__variable_store",)), [],
"No variable store.")
with tf.variable_scope("my_scope") as scope:
self.assertTrue(len(tf.get_collection(("__variable_store",))) > 0,
"There is now a variable store.")
var_store = tf.get_collection(("__variable_store",))[0]
self.assertDictEqual(var_store._vars, {},
"but with variable in it.")
a = tf.Variable(0., trainable=False, name='test')
self.assertEqual(a.name, "my_scope/test:0")
var_store = tf.get_collection(("__variable_store",))[0]
self.assertDictEqual(var_store._vars, {},
"Still no variable in the store.")
b = tf.get_variable('test', shape=(), trainable=False)
self.assertEqual(b.name, "my_scope/test_1:0")
var_store = tf.get_collection(("__variable_store",))[0]
self.assertDictEqual(
var_store._vars, {"my_scope/test": b},
"'b' is in the store, but notice the difference between its name and its key in the store.")
with self.assertRaises(ValueError) as exception_context_manager:
tf.get_variable('test', shape=(), trainable=False)
exception = exception_context_manager.exception
self.assertStartsWith(str(exception),
"Variable my_scope/test already exists, disallowed.")