Skip to content

Commit 1f24f0f

Browse files
ke like li
authored and committed
Update aix360.algorithms.rule_induction.trxf.pmml_export.models.data_dictionary.DataField for categorical string
Signed-off-by: Ke LI <[email protected]>
1 parent 34b2e02 commit 1f24f0f

File tree

5 files changed

+162
-78
lines changed

5 files changed

+162
-78
lines changed

aix360/algorithms/rule_induction/trxf/pmml_export/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
from .data_dictionary import DataField
33
from .data_dictionary import DataType
44
from .data_dictionary import OpType
5+
from .data_dictionary import Value
6+
from .data_dictionary import Restriction
57
from .mining_schema import MiningField
68
from .mining_schema import MiningFieldUsageType
79
from .mining_schema import MiningSchema

aix360/algorithms/rule_induction/trxf/pmml_export/models/data_dictionary.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,11 +23,29 @@ class OpType(enum.Enum):
2323
continuous = 2
2424

2525

26+
class Restriction(enum.Enum):
    """Classification attached to a listed ``Value`` of a data field.

    Members are ``valid``, ``invalid`` and ``missing``; ``valid`` is the
    default used by ``Value``.
    """
    valid = 0
    invalid = 1
    missing = 2


@dataclass(frozen=True)
class Value:
    """A single enumerated value of a ``DataField``.

    Attributes are documented inline; the instance is immutable (frozen).
    """
    # The literal value, always carried as a string.
    value: str = field()
    # How the value is classified; defaults to ``Restriction.valid``.
    # NOTE: the attribute name shadows the ``property`` builtin inside this
    # class body, but it is the keyword callers pass, so it must stay.
    property: Restriction = field(default=Restriction.valid)


@dataclass(frozen=True)
class DataField:
    """Description of one field: its name, operational type and data type.

    ``values`` optionally enumerates the values of the field. A non-empty
    ``values`` is only accepted for string fields whose ``optype`` is
    ordinal or categorical.

    Raises:
        ValueError: if ``values`` is non-empty and ``dataType`` is not
            ``DataType.string`` or ``optype`` is neither ``OpType.ordinal``
            nor ``OpType.categorical``.
    """
    name: str = field()
    optype: OpType = field()
    dataType: DataType = field()
    # ``None`` and the empty list both mean "no enumerated values".
    values: typing.Optional[typing.List[Value]] = field(default_factory=list)

    def __post_init__(self):
        """Validate that enumerated values are only given where allowed."""
        if not self.values:
            return
        if self.dataType is not DataType.string:
            raise ValueError(
                "DataField {!r}: enumerated values require dataType 'string', got {!r}".format(
                    self.name, self.dataType))
        if self.optype not in (OpType.ordinal, OpType.categorical):
            raise ValueError(
                "DataField {!r}: enumerated values require optype 'ordinal' or 'categorical', got {!r}".format(
                    self.name, self.optype))
3149

3250

3351
@dataclass(frozen=True)

aix360/algorithms/rule_induction/trxf/pmml_export/serializer/nyoka_serializer.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -49,10 +49,17 @@ def _nyoka_pmml_model(
4949

5050
def _nyoka_data_dictionary(self, data_dictionary: models.DataDictionary) -> nyoka_pmml.DataDictionary:
    """Build the nyoka ``DataDictionary`` for a model-level data dictionary.

    When ``dataFields`` is ``None`` or empty the result has
    ``numberOfFields=0`` and ``DataField=None``; otherwise every field is
    converted with ``_nyoka_data_field``.
    """
    model_fields = data_dictionary.dataFields
    if not model_fields:
        return nyoka_pmml.DataDictionary(numberOfFields=0, DataField=None)

    return nyoka_pmml.DataDictionary(
        numberOfFields=len(model_fields),
        DataField=[self._nyoka_data_field(f) for f in model_fields])
55+
56+
def _nyoka_data_field(self, data_field: models.DataField) -> nyoka_pmml.DataField:
    """Build the nyoka ``DataField`` for a single model-level data field.

    Enum-typed attributes are passed by member name. ``Value`` is ``None``
    when the field lists no values, otherwise one ``nyoka_pmml.Value`` per
    listed value.
    """
    listed_values = data_field.values
    nyoka_values = None
    if listed_values:
        nyoka_values = [
            nyoka_pmml.Value(value=entry.value, property=entry.property.name)
            for entry in listed_values]

    return nyoka_pmml.DataField(
        name=data_field.name,
        optype=data_field.optype.name,
        dataType=data_field.dataType.name,
        Value=nyoka_values)
5663

5764
def _nyoka_rule_set_model(self, rule_set_model: models.RuleSetModel) -> nyoka_pmml.RuleSetModel:
5865
return nyoka_pmml.RuleSetModel(
Lines changed: 131 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,131 @@
1+
import datetime
2+
from unittest import TestCase
3+
4+
import nyoka.base.constants as nyoka_constants
5+
import xmltodict
6+
7+
import aix360.algorithms.rule_induction.trxf.pmml_export.models as models
8+
import aix360.algorithms.rule_induction.trxf.pmml_export.serializer as serializer
9+
10+
11+
class TestNyokaSerializer(TestCase):
    """Tests for serializing the header, data dictionary and mining schema."""

    # One timestamp is shared between the serializer and the expected XML so
    # that the <Timestamp> element can be compared exactly.
    now = datetime.datetime.now()
    nyokaSerializer = serializer.NyokaSerializer(timestamp=now)

    def test_serialize_model_with_version_and_header(self):
        """An empty model serializes to a PMML root with version and header."""
        # arrange
        model = models.SimplePMMLRuleSetModel(dataDictionary=None, ruleSetModel=None)  # noqa

        # when
        serialized = self.nyokaSerializer.serialize(model)

        # assert
        expected = '''
        <PMML xmlns="http://www.dmg.org/PMML-4_4" version="{version}">
            <Header copyright="Copyright IBM Corp, exported to PMML by Nyoka (c) 2022 Software AG"
                    description="Default description">
                <Application name="SimpleRuleSetExport" version="0.0.1"/>
                <Timestamp>{time}</Timestamp>
            </Header>
        </PMML>
        '''.format(version=nyoka_constants.PMML_SCHEMA.VERSION, time=self.now)
        self.assertEqual(xmltodict.parse(xml_input=serialized), xmltodict.parse(xml_input=expected))

    def test_serialize_data_dictionary(self):
        """Fields without listed values serialize as attribute-only elements."""
        # arrange
        data_dictionary = models.DataDictionary(
            dataFields=[
                models.DataField(name='toto0', optype=models.OpType.continuous, dataType=models.DataType.float),
                models.DataField(name='toto1', optype=models.OpType.ordinal, dataType=models.DataType.string),
                models.DataField(name='toto2', optype=models.OpType.categorical, dataType=models.DataType.boolean),
                models.DataField(name='toto3', optype=models.OpType.categorical, dataType=models.DataType.integer)])
        model = models.SimplePMMLRuleSetModel(dataDictionary=data_dictionary, ruleSetModel=None)  # noqa

        # when
        serialized = self.nyokaSerializer.serialize(model)
        res_data_dictionary_dict = xmltodict.parse(xml_input=serialized)['PMML']['DataDictionary']

        # assert
        expected = '''
        <DataDictionary numberOfFields="4">
            <DataField name="toto0" optype="continuous" dataType="float"/>
            <DataField name="toto1" optype="ordinal" dataType="string"/>
            <DataField name="toto2" optype="categorical" dataType="boolean"/>
            <DataField name="toto3" optype="categorical" dataType="integer"/>
        </DataDictionary>
        '''
        self.assertEqual(res_data_dictionary_dict, xmltodict.parse(xml_input=expected)['DataDictionary'])

    def test_should_raise_error_if_categorical_value_is_not_string(self):
        """Listed values are rejected for non-string or continuous fields."""
        # wrong dataType: categorical but float
        with self.assertRaises(ValueError):
            models.DataDictionary(
                dataFields=[
                    models.DataField(
                        name='toto0',
                        optype=models.OpType.categorical,
                        dataType=models.DataType.float,
                        values=[models.Value(value='unexpected')])])
        # wrong optype: string but continuous
        with self.assertRaises(ValueError):
            models.DataDictionary(
                dataFields=[
                    models.DataField(
                        name='toto0',
                        optype=models.OpType.continuous,
                        dataType=models.DataType.string,
                        values=[models.Value(value='unexpected')])])

    def test_serialize_data_dictionary_with_categorical_value(self):
        """Listed values of a categorical string field become <Value> children."""
        # arrange
        data_dictionary = models.DataDictionary(
            dataFields=[
                models.DataField(
                    name='toto0',
                    optype=models.OpType.categorical,
                    dataType=models.DataType.string,
                    values=[
                        models.Value(value='val1'),
                        models.Value(value='val2', property=models.Restriction.valid),
                        models.Value(value='unknown', property=models.Restriction.invalid),
                        models.Value(value='unknown', property=models.Restriction.missing)])])
        model = models.SimplePMMLRuleSetModel(dataDictionary=data_dictionary, ruleSetModel=None)  # noqa

        # when
        serialized = self.nyokaSerializer.serialize(model)
        res_data_dictionary_dict = xmltodict.parse(xml_input=serialized)['PMML']['DataDictionary']

        # assert
        expected = '''
        <DataDictionary numberOfFields="1">
            <DataField name="toto0" optype="categorical" dataType="string">
                <Value value="val1"/>
                <Value value="val2"/>
                <Value value="unknown" property="invalid"/>
                <Value value="unknown" property="missing"/>
            </DataField>
        </DataDictionary>
        '''
        self.assertEqual(res_data_dictionary_dict, xmltodict.parse(xml_input=expected)['DataDictionary'])

    def test_serialize_mining_schema(self):
        """Mining fields serialize with their name and usage type."""
        # arrange
        mining_schema = models.MiningSchema(
            miningFields=[
                models.MiningField(name='toto0', usageType=models.MiningFieldUsageType.active),
                models.MiningField(name='toto1', usageType=models.MiningFieldUsageType.target)])
        model = models.SimplePMMLRuleSetModel(
            dataDictionary=models.DataDictionary(dataFields=None),  # noqa
            ruleSetModel=models.RuleSetModel(miningSchema=mining_schema, ruleSet=None))  # noqa

        # when
        serialized = self.nyokaSerializer.serialize(model)
        res_mining_schema_dict = xmltodict.parse(xml_input=serialized)['PMML']['RuleSetModel']['MiningSchema']

        # assert
        expected = '''
        <MiningSchema>
            <MiningField name="toto0" usageType="active"/>
            <MiningField name="toto1" usageType="target"/>
        </MiningSchema>
        '''
        self.assertEqual(res_mining_schema_dict, xmltodict.parse(xml_input=expected)['MiningSchema'])

tests/rule_induction/trxf/pmml_export/test_nyoka_serialize_ruleset.py

Lines changed: 0 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -1,86 +1,12 @@
11
from unittest import TestCase
22

3-
import nyoka.base.constants as nyoka_constants
43
import xmltodict
54

65
import aix360.algorithms.rule_induction.trxf.pmml_export.models as models
76
import aix360.algorithms.rule_induction.trxf.pmml_export.serializer as serializer
87

98

109
class TestNyokaSerializer(TestCase):
11-
def test_serialize_model_with_version_and_header(self):
12-
# arrange
13-
srz = serializer.NyokaSerializer()
14-
model = models.SimplePMMLRuleSetModel(dataDictionary=None, ruleSetModel=None) # noqa
15-
16-
# when
17-
serialized = srz.serialize(model)
18-
res_dict_no_timestamp = xmltodict.parse(xml_input=serialized)
19-
res_dict_no_timestamp['PMML']['Header'] = {
20-
k: res_dict_no_timestamp['PMML']['Header'][k]
21-
for k in res_dict_no_timestamp['PMML']['Header'] if k != 'Timestamp'}
22-
23-
# assert
24-
expected = '''
25-
<PMML xmlns="http://www.dmg.org/PMML-4_4" version="{version}">
26-
<Header copyright="Copyright IBM Corp, exported to PMML by Nyoka (c) 2022 Software AG"
27-
description="Default description">
28-
<Application name="SimpleRuleSetExport" version="0.0.1"/>
29-
</Header>
30-
</PMML>
31-
'''.format(version=nyoka_constants.PMML_SCHEMA.VERSION)
32-
self.assertEqual(res_dict_no_timestamp, xmltodict.parse(xml_input=expected))
33-
34-
def test_serialize_data_dictionary(self):
35-
# arrange
36-
srz = serializer.NyokaSerializer()
37-
data_dictionary = models.DataDictionary(
38-
dataFields=[
39-
models.DataField(name='toto0', optype=models.OpType.continuous, dataType=models.DataType.float),
40-
models.DataField(name='toto1', optype=models.OpType.ordinal, dataType=models.DataType.string),
41-
models.DataField(name='toto2', optype=models.OpType.categorical, dataType=models.DataType.boolean),
42-
models.DataField(name='toto3', optype=models.OpType.categorical, dataType=models.DataType.integer)])
43-
model = models.SimplePMMLRuleSetModel(dataDictionary=data_dictionary, ruleSetModel=None) # noqa
44-
45-
# when
46-
serialized = srz.serialize(model)
47-
res_data_dictionary_dict = xmltodict.parse(xml_input=serialized)['PMML']['DataDictionary']
48-
49-
# assert
50-
expected = '''
51-
<DataDictionary numberOfFields="4">
52-
<DataField name="toto0" optype="continuous" dataType="float"/>
53-
<DataField name="toto1" optype="ordinal" dataType="string"/>
54-
<DataField name="toto2" optype="categorical" dataType="boolean"/>
55-
<DataField name="toto3" optype="categorical" dataType="integer"/>
56-
</DataDictionary>
57-
'''
58-
self.assertEqual(res_data_dictionary_dict, xmltodict.parse(xml_input=expected)['DataDictionary'])
59-
60-
def test_serialize_mining_schema(self):
61-
# arrange
62-
srz = serializer.NyokaSerializer()
63-
mining_schema = models.MiningSchema(
64-
miningFields=[
65-
models.MiningField(name='toto0', usageType=models.MiningFieldUsageType.active),
66-
models.MiningField(name='toto1', usageType=models.MiningFieldUsageType.target)])
67-
model = models.SimplePMMLRuleSetModel(
68-
dataDictionary=models.DataDictionary(dataFields=None), # noqa
69-
ruleSetModel=models.RuleSetModel(miningSchema=mining_schema, ruleSet=None)) # noqa
70-
71-
# when
72-
serialized = srz.serialize(model)
73-
res_data_dictionary_dict = xmltodict.parse(xml_input=serialized)['PMML']['RuleSetModel']['MiningSchema']
74-
75-
# assert
76-
expected = '''
77-
<MiningSchema>
78-
<MiningField name="toto0" usageType="active"/>
79-
<MiningField name="toto1" usageType="target"/>
80-
</MiningSchema>
81-
'''
82-
self.assertEqual(res_data_dictionary_dict, xmltodict.parse(xml_input=expected)['MiningSchema'])
83-
8410
def test_serialize_rule_set(self):
8511
# arrange
8612
srz = serializer.NyokaSerializer()

0 commit comments

Comments
 (0)