From 5fde2fd2a65b4c1dd7ffa7cce95b745e5cc8463c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=85=E6=A2=A6?= Date: Fri, 17 Apr 2026 11:01:27 +0800 Subject: [PATCH 1/4] Fix shared embeddings and string sparse validation --- deepctr/feature_column.py | 19 ++++++++++++++ deepctr/inputs.py | 54 +++++++++++++++++++++++++++++---------- tests/feature_test.py | 27 ++++++++++++++++++++ tests/models/MTL_test.py | 11 ++++++++ 4 files changed, 97 insertions(+), 14 deletions(-) diff --git a/deepctr/feature_column.py b/deepctr/feature_column.py index a80a6f84..31629f95 100644 --- a/deepctr/feature_column.py +++ b/deepctr/feature_column.py @@ -14,6 +14,23 @@ DEFAULT_GROUP_NAME = "default_group" +def _is_string_dtype(dtype): + try: + return tf.as_dtype(dtype) == tf.string + except TypeError: + return dtype == "string" + + +def _check_sparse_feature_dtype(fc): + if _is_string_dtype(fc.dtype) and not fc.use_hash: + raise ValueError( + "SparseFeat(name='{}', dtype='string') requires use_hash=True " + "so string ids can be converted before embedding lookup. " + "Alternatively, encode the feature values to integer ids before " + "passing them to DeepCTR.".format(fc.name) + ) + + class SparseFeat(namedtuple('SparseFeat', ['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'vocabulary_path', 'dtype', 'embeddings_initializer', 'embedding_name', @@ -129,12 +146,14 @@ def build_input_features(feature_columns, prefix=''): input_features = OrderedDict() for fc in feature_columns: if isinstance(fc, SparseFeat): + _check_sparse_feature_dtype(fc) input_features[fc.name] = Input( shape=(1,), name=prefix + fc.name, dtype=fc.dtype) elif isinstance(fc, DenseFeat): input_features[fc.name] = Input( shape=(fc.dimension,), name=prefix + fc.name, dtype=fc.dtype) elif isinstance(fc, VarLenSparseFeat): + _check_sparse_feature_dtype(fc) input_features[fc.name] = Input(shape=(fc.maxlen,), name=prefix + fc.name, dtype=fc.dtype) if fc.weight_name is not None: diff --git a/deepctr/inputs.py b/deepctr/inputs.py index 3c2bdbae..aeadc439 100644 --- a/deepctr/inputs.py +++ b/deepctr/inputs.py @@ -16,6 +16,27 @@ from .layers.utils import Hash +def _create_embedding_layer(feat, l2_reg, prefix, name_suffix, mask_zero=False): + emb = Embedding(feat.vocabulary_size, feat.embedding_dim, + embeddings_initializer=feat.embeddings_initializer, + embeddings_regularizer=l2(l2_reg), + name=prefix + '_' + name_suffix + '_' + feat.embedding_name, + mask_zero=mask_zero) + emb.trainable = feat.trainable + return emb + + +def _check_embedding_compatible(embedding_name, existing_feat, feat): + for attr in ('vocabulary_size', 'embedding_dim', 'trainable'): + if getattr(existing_feat, attr) != getattr(feat, attr): + raise ValueError( + "Feature columns with the same embedding_name must share the same " + "{}. embedding_name='{}' has {} and {}.".format( + attr, embedding_name, getattr(existing_feat, attr), getattr(feat, attr) + ) + ) + + def get_inputs_list(inputs): return list(chain(*list(map(lambda x: x.values(), filter(lambda x: x is not None, inputs))))) @@ -23,25 +44,30 @@ def get_inputs_list(inputs): def create_embedding_dict(sparse_feature_columns, varlen_sparse_feature_columns, seed, l2_reg, prefix='sparse_', seq_mask_zero=True): sparse_embedding = {} + embedding_feature_dict = {} + varlen_embedding_names = set( + feat.embedding_name for feat in varlen_sparse_feature_columns + ) if varlen_sparse_feature_columns else set() + for feat in sparse_feature_columns: - emb = Embedding(feat.vocabulary_size, feat.embedding_dim, - embeddings_initializer=feat.embeddings_initializer, - embeddings_regularizer=l2(l2_reg), - name=prefix + '_emb_' + feat.embedding_name) - emb.trainable = feat.trainable - sparse_embedding[feat.embedding_name] = emb + embedding_name = feat.embedding_name + if embedding_name in sparse_embedding: + _check_embedding_compatible(embedding_name, embedding_feature_dict[embedding_name], feat) + continue + mask_zero = seq_mask_zero and feat.embedding_name in varlen_embedding_names + emb = _create_embedding_layer(feat, l2_reg, prefix, 'emb', mask_zero) + sparse_embedding[embedding_name] = emb + embedding_feature_dict[embedding_name] = feat if varlen_sparse_feature_columns and len(varlen_sparse_feature_columns) > 0: for feat in varlen_sparse_feature_columns: - # if feat.name not in sparse_embedding: - emb = Embedding(feat.vocabulary_size, feat.embedding_dim, - embeddings_initializer=feat.embeddings_initializer, - embeddings_regularizer=l2( - l2_reg), - name=prefix + '_seq_emb_' + feat.name, - mask_zero=seq_mask_zero) - emb.trainable = feat.trainable + embedding_name = feat.embedding_name + if embedding_name in sparse_embedding: + _check_embedding_compatible(embedding_name, embedding_feature_dict[embedding_name], feat) + continue + emb = _create_embedding_layer(feat, l2_reg, prefix, 'seq_emb', seq_mask_zero) sparse_embedding[feat.embedding_name] = emb + embedding_feature_dict[feat.embedding_name] = feat return sparse_embedding diff --git a/tests/feature_test.py b/tests/feature_test.py index 35005fb7..81de806f 100644 --- a/tests/feature_test.py +++ b/tests/feature_test.py @@ -1,6 +1,8 @@ from deepctr.models import DeepFM from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat, get_feature_names +from deepctr.inputs import create_embedding_matrix import numpy as np +import pytest def test_long_dense_vector(): @@ -28,3 +30,28 @@ def test_feature_column_sparsefeat_vocabulary_path(): vlsf = VarLenSparseFeat(sf, 6) if vlsf.vocabulary_path != vocab_path: raise ValueError("vlsf.vocabulary_path is invalid") + + +def test_create_embedding_matrix_reuses_same_embedding_name(): + feature_columns = [ + SparseFeat('item_id', 4, embedding_dim=8), + SparseFeat('item_id_copy', 4, embedding_dim=8, embedding_name='item_id'), + VarLenSparseFeat(SparseFeat('hist_item_id', 4, embedding_dim=8, embedding_name='item_id'), maxlen=3), + VarLenSparseFeat(SparseFeat('neg_hist_item_id', 4, embedding_dim=8, embedding_name='item_id'), maxlen=3), + ] + + embedding_dict = create_embedding_matrix(feature_columns, l2_reg=0, seed=1024) + + assert list(embedding_dict.keys()) == ['item_id'] + assert embedding_dict['item_id'].name == 'sparse_emb_item_id' + assert embedding_dict['item_id'].mask_zero is True + + +def test_create_embedding_matrix_rejects_inconsistent_shared_embedding(): + feature_columns = [ + SparseFeat('item_id', 4, embedding_dim=8), + VarLenSparseFeat(SparseFeat('hist_item_id', 5, embedding_dim=8, embedding_name='item_id'), maxlen=3), + ] + + with pytest.raises(ValueError, match="same embedding_name"): + create_embedding_matrix(feature_columns, l2_reg=0, seed=1024) diff --git a/tests/models/MTL_test.py b/tests/models/MTL_test.py index a18a6b64..5da008d3 100644 --- a/tests/models/MTL_test.py +++ b/tests/models/MTL_test.py @@ -1,6 +1,7 @@ import pytest import tensorflow as tf +from deepctr.feature_column import SparseFeat from deepctr.models.multitask import SharedBottom, ESMM, MMOE, PLE from ..utils_mtl import get_mtl_test_data, check_mtl_model @@ -27,6 +28,16 @@ def test_ESMM(): check_mtl_model(model, model_name, x, y_list, task_types=['binary', 'binary']) +def test_ESMM_string_sparse_requires_hash(): + with pytest.raises(ValueError, match="use_hash=True"): + ESMM([SparseFeat('user_id', 10, dtype='string')], tower_dnn_hidden_units=(8,)) + + +def test_ESMM_string_sparse_with_hash(): + model = ESMM([SparseFeat('user_id', 10, use_hash=True, dtype='string')], tower_dnn_hidden_units=(8,)) + assert len(model.outputs) == 2 + + def test_MMOE(): if tf.__version__ == "1.15.0": # slow in tf 1.15 return From d034630937cd4d3179d09932bf3391b0ec6ec06a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=85=E6=A2=A6?= Date: Wed, 22 Apr 2026 19:31:02 +0800 Subject: [PATCH 2/4] Fix DIEN shared embedding fixture --- tests/models/DIEN_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/DIEN_test.py b/tests/models/DIEN_test.py index 0cad9bf3..ec226c15 100644 --- a/tests/models/DIEN_test.py +++ b/tests/models/DIEN_test.py @@ -9,10 +9,10 @@ def get_xy_fd(use_neg=False, hash_flag=False): - feature_columns = [SparseFeat('user', 3, hash_flag), - SparseFeat('gender', 2, hash_flag), - SparseFeat('item', 3 + 1, hash_flag), - SparseFeat('item_gender', 2 + 1, hash_flag), + feature_columns = [SparseFeat('user', 3, use_hash=hash_flag), + SparseFeat('gender', 2, use_hash=hash_flag), + SparseFeat('item', 3 + 1, embedding_dim=8, use_hash=hash_flag), + SparseFeat('item_gender', 2 + 1, embedding_dim=4, use_hash=hash_flag), DenseFeat('score', 1)] feature_columns += [ From d8fa7fd0933fbcad7da19bb12c165c9b8a823b92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=85=E6=A2=A6?= Date: Wed, 22 Apr 2026 19:40:31 +0800 Subject: [PATCH 3/4] Sync ReadTheDocs build config --- .readthedocs.yml | 12 ++++++++++-- docs/requirements.readthedocs.txt | 14 ++++++++++++-- docs/source/conf.py | 6 +++++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 1afb6e70..03a5ead0 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,5 +1,13 @@ +version: 2 + build: - image: latest + os: ubuntu-22.04 + tools: + python: "3.10" + +sphinx: + configuration: docs/source/conf.py python: - version: 3.6 \ No newline at end of file + install: + - requirements: docs/requirements.readthedocs.txt diff --git a/docs/requirements.readthedocs.txt b/docs/requirements.readthedocs.txt index 10942403..baf4a3dd 100644 --- a/docs/requirements.readthedocs.txt +++ b/docs/requirements.readthedocs.txt @@ -1,2 +1,12 @@ -tensorflow==2.6.2 -recommonmark==0.7.1 \ No newline at end of file +numpy<2 +Jinja2<3.1 +docutils<0.18 +sphinx==4.5.0 +sphinx-rtd-theme==0.5.2 +sphinxcontrib-applehelp==1.0.2 +sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-htmlhelp==2.0.0 +sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-serializinghtml==1.1.5 +recommonmark==0.7.1 +tensorflow==2.15.0 diff --git a/docs/source/conf.py b/docs/source/conf.py index a3b84055..63b2c5bd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -44,6 +44,7 @@ 'sphinx.ext.ifconfig', 'sphinx.ext.viewcode', 'sphinx.ext.githubpages', + 'recommonmark', ] # Add any paths that contain templates here, relative to this directory. @@ -52,7 +53,10 @@ # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # -source_suffix = ['.rst', '.md'] +source_suffix = { + '.rst': 'restructuredtext', + '.md': 'markdown', +} #source_suffix = '.rst' # The master toctree document. From 242ec662c7b443ed6e6e2f8a570c82f2b5e5c97a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=85=E6=A2=A6?= Date: Sat, 25 Apr 2026 00:22:00 +0800 Subject: [PATCH 4/4] Avoid assert in new regression tests --- tests/feature_test.py | 9 ++++++--- tests/models/MTL_test.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/feature_test.py b/tests/feature_test.py index 81de806f..31690955 100644 --- a/tests/feature_test.py +++ b/tests/feature_test.py @@ -42,9 +42,12 @@ def test_create_embedding_matrix_reuses_same_embedding_name(): embedding_dict = create_embedding_matrix(feature_columns, l2_reg=0, seed=1024) - assert list(embedding_dict.keys()) == ['item_id'] - assert embedding_dict['item_id'].name == 'sparse_emb_item_id' - assert embedding_dict['item_id'].mask_zero is True + if list(embedding_dict.keys()) != ['item_id']: + raise AssertionError("Expected a single shared embedding keyed by 'item_id'") + if embedding_dict['item_id'].name != 'sparse_emb_item_id': + raise AssertionError("Expected the shared embedding layer to use the embedding_name-based layer name") + if embedding_dict['item_id'].mask_zero is not True: + raise AssertionError("Expected shared sequence embeddings to preserve mask_zero") def test_create_embedding_matrix_rejects_inconsistent_shared_embedding(): diff --git a/tests/models/MTL_test.py b/tests/models/MTL_test.py index 5da008d3..e4791863 100644 --- a/tests/models/MTL_test.py +++ b/tests/models/MTL_test.py @@ -35,7 +35,8 @@ def test_ESMM_string_sparse_requires_hash(): def test_ESMM_string_sparse_with_hash(): model = ESMM([SparseFeat('user_id', 10, use_hash=True, dtype='string')], tower_dnn_hidden_units=(8,)) - assert len(model.outputs) == 2 + if len(model.outputs) != 2: + raise AssertionError("Expected ESMM to build two task outputs") def test_MMOE():