You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
# Compile the model
# The loss is the negative log-likelihood of the targets under the model
# output; it assumes the model's output exposes `log_prob`.
model.compile(
    optimizer=tf_keras.optimizers.Adam(),
    loss=lambda y, p_y: -p_y.log_prob(y),
)
model.fit(x_train, y_train, epochs=3)
With the following code, how can I use tfp.math.psd_kernels.Polynomial with feature_ndims=6? I usually get an error. Please help.
code =
class PolynomialKernelFn(tf_keras.layers.Layer):
    """Keras layer whose `call` hands back the kernel stored on `self.kernel`.

    The constructor accepts the polynomial-kernel hyperparameters
    (`bias_amplitude`, `slope_amplitude`, `shift`, `exponent`,
    `feature_ndims`) and forwards any remaining keyword arguments to the
    base `Layer`.

    NOTE(review): the visible constructor neither stores those
    hyperparameters nor assigns `self.kernel`, so `call` raises
    AttributeError unless `kernel` is provided elsewhere (for example as a
    property or by a subclass) -- confirm against the full class.
    """

    def __init__(self, bias_amplitude=0.0, slope_amplitude=1.0, shift=0.0,
                 exponent=3.0, feature_ndims=1, **kwargs):
        # The dunder spelling matters: a method named plain `init` is never
        # invoked on construction, and the base Layer has no `init` to call.
        super(PolynomialKernelFn, self).__init__(**kwargs)
        # NOTE(review): `dtype` is looked up but not used in the visible code.
        dtype = kwargs.get('dtype', None)
        self.jitter = 1e-4  # Adding a small jitter for numerical stability

    def call(self, inputs):
        """Return `self.kernel`; `inputs` is ignored."""
        return self.kernel
# end_class
# Example data: 100 samples with 6 features each, drawn uniformly from [-3, 3).
x_train = np.random.uniform(-3., 3., size=(100, 6))
# The targets are the inputs themselves, so y_train has the same shape as x_train.
y_train = x_train
print('x_train', x_train.shape, 'y_train', y_train.shape)
x_range = [np.min(x_train), np.max(x_train)]

# Inducing index points: one row per inducing point, one column per feature.
num_inducing_points = 18
induc_idx_points = np.random.uniform(
    -3.0, 3.0, size=(num_inducing_points, x_train.shape[1]))
induc_idx_point_init = tf.constant_initializer(induc_idx_points)
print('induc_idx_points shape:', induc_idx_points.shape)

# Scalar initial noise value, cast to the training data's dtype.
noise0 = tf.constant_initializer(np.array(0.54).astype(x_train.dtype))
# Alternative with an explicit (1, 1) shape:
# noise0 = tf.constant_initializer(np.array([[0.54]]).astype(x_train.dtype))
print('noise0', noise0.value.shape)

# NOTE(review): per the TFP kernel documentation, `feature_ndims` is the
# number of rightmost *dimensions* that together form one feature, not the
# number of features. The index points here are rank-2 arrays of shape
# (n, 6), i.e. each feature is a length-6 vector, which would correspond to
# feature_ndims=1. feature_ndims=6 asks for six feature dimensions that
# these arrays do not have -- presumably the source of the shape error;
# confirm before changing.
polynomial_kernel = tfp.math.psd_kernels.Polynomial(
    bias_amplitude=1.0,
    slope_amplitude=1.0,
    shift=0.0,
    exponent=2,
    feature_ndims=6,
    validate_args=True,
)

# Build model
# NOTE(review): only the input layer is visible in this layer list; the loss
# below needs an output exposing `log_prob`, so further layers are presumably
# missing from this snippet -- confirm.
model = tf_keras.Sequential([
    tf_keras.layers.InputLayer(input_shape=(x_train.shape[1],)),
])
model.summary()

# Compile the model
# The loss is the negative log-likelihood of the targets under the model output.
model.compile(
    optimizer=tf_keras.optimizers.Adam(),
    loss=lambda y, p_y: -p_y.log_prob(y),
)
model.fit(x_train, y_train, epochs=3)

# Predict on 30 evenly spaced values in [-3, 3], reshaped into 5 rows of 6 features.
x_test = np.linspace(-3., 3., 30).reshape(-1, 6)
y_pred = model.predict(x_test)
error = TypeError: Eager execution of tf.constant with unsupported shape. Tensor [[ 2.6181169 -0.63000023 0.88848853 2.8256464 0.7070044 0.10845122]
[ 1.7892019 0.9782654 0.7722825 -1.3020381 -2.2474375 -2.7869945 ]
[-1.7044429 1.3536608 1.6787233 -1.0234025 -0.56131506 -1.3617553 ]
[-0.67438734 -2.781567 -2.8270533 1.1969172 0.6445283 -2.27967 ]
[ 2.360385 -0.04742741 0.59732205 -2.933202 0.7861253 2.8390534 ]
[ 2.8995507 -2.7624567 1.4198784 2.6410117 -1.4379357 -1.7750715 ]
[ 1.152713 -1.1086357 -1.0503062 -1.4047132 2.0060854 -1.1540136 ]
[-2.5986419 0.97180235 1.4190216 -1.3796597 0.29559672 -1.4235702 ]
[ 2.4803004 2.498159 -0.7950366 1.8512287 2.291289 -0.8069504 ]
[ 2.9575715 1.8804133 2.8104806 -1.9198297 -1.3981376 2.7783432 ]
[ 0.5030356 -1.099595 0.3222227 0.6970968 -0.53885055 -0.7117024 ]
[ 2.3374352 -1.5795385 2.6345415 0.8455149 2.956886 0.8089901 ]
[ 2.72245 -0.10884786 -0.5315385 -0.4300995 -1.4305451 0.99511087]
[ 1.7424549 2.3463657 2.444201 0.5814336 -0.8842765 0.9178352 ]
[ 2.8574386 2.663537 1.0705398 0.31068817 0.22017056 1.424388 ]
[ 1.5786017 1.0860821 -2.569758 -0.1683704 -0.87796944 2.7431 ]
[-2.030426 -2.789794 -1.2321734 -0.4800363 1.2280728 2.926302 ]
[ 1.3148888 2.8946493 2.2137816 -1.0343827 1.1032671 -0.5407205 ]] (converted from [[ 2.61811684 -0.63000021 0.88848855 2.82564638 0.70700445 0.10845122]
[ 1.78920188 0.97826541 0.77228246 -1.30203812 -2.24743741 -2.78699451]
[-1.70444284 1.35366083 1.67872335 -1.02340239 -0.56131508 -1.36175523]
[-0.67438734 -2.78156708 -2.82705339 1.19691717 0.64452832 -2.27967009]
[ 2.36038489 -0.04742741 0.59732205 -2.93320213 0.78612532 2.83905331]
[ 2.89955057 -2.76245673 1.41987839 2.64101165 -1.43793566 -1.77507154]
[ 1.15271291 -1.10863564 -1.05030615 -1.40471315 2.00608531 -1.15401369]
[-2.59864192 0.97180237 1.41902165 -1.37965962 0.29559672 -1.4235702 ]
[ 2.48030032 2.49815891 -0.79503662 1.85122868 2.291289 -0.80695038]
[ 2.95757162 1.88041332 2.81048062 -1.91982973 -1.39813751 2.77834323]
[ 0.50303558 -1.09959492 0.3222227 0.69709683 -0.53885057 -0.71170238]
[ 2.33743521 -1.57953841 2.63454156 0.84551489 2.95688615 0.80899011]
[ 2.72245013 -0.10884786 -0.53153849 -0.43009949 -1.43054515 0.99511088]
[ 1.74245491 2.3463657 2.44420105 0.58143359 -0.88427652 0.91783519]
[ 2.85743859 2.66353695 1.0705398 0.31068815 0.22017056 1.42438809]
[ 1.5786017 1.08608213 -2.56975798 -0.16837039 -0.87796944 2.74310004]
[-2.03042604 -2.78979406 -1.23217341 -0.48003628 1.22807274 2.92630186]
[ 1.31488889 2.89464926 2.21378162 -1.03438269 1.10326708 -0.5407205 ]]) has 108 elements, but got
`shape` (1, 18, None, 1) with `None` elements).

The text was updated successfully, but these errors were encountered: