Skip to content

Commit 34b2e02

Browse files
ke like li
authored andcommitted
Add tests/rule_induction/trxf/pmml_export/test_nyoka_serialize_integration.py
Signed-off-by: Ke LI <[email protected]>
1 parent 412658f commit 34b2e02

File tree

5 files changed

+151
-3
lines changed

5 files changed

+151
-3
lines changed

aix360/algorithms/rule_induction/trxf/pmml_export/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .predicate import CompoundPredicate
1010
from .predicate import Operator
1111
from .predicate import SimplePredicate
12+
from .predicate import TruePredicate
1213
from .rule import SimpleRule
1314
from .rule import DEFAULT_WEIGHT
1415
from .rule import DEFAULT_CONFIDENCE

aix360/algorithms/rule_induction/trxf/pmml_export/models/attribute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ class ComplexPartialScore:
2121
@dataclass(frozen=True)
2222
class Attribute:
2323
score: typing.Union[str, ComplexPartialScore]
24-
predicate: typing.Union[predicate.SimplePredicate, predicate.CompoundPredicate]
24+
predicate: typing.Union[predicate.SimplePredicate, predicate.CompoundPredicate, predicate.TruePredicate]

aix360/algorithms/rule_induction/trxf/pmml_export/models/predicate.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@ class Operator(enum.Enum):
1616
lessOrEqual = 3
1717
greaterThan = 4
1818
greaterOrEqual = 5
19+
isMissing = 6
1920

2021

2122
@dataclass(frozen=True)
2223
class SimplePredicate:
2324
operator: Operator = field()
24-
value: str = field()
2525
field: str = field()
26+
value: typing.Optional[str] = None
27+
28+
29+
@dataclass(frozen=True)
30+
class TruePredicate:
31+
pass
2632

2733

2834
# Use functional api to add aliases for `or` and `and`

aix360/algorithms/rule_induction/trxf/pmml_export/serializer/nyoka_serializer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,6 @@ def _nyoka_pmml_attributes(self, attribute: models.Attribute) -> nyoka_pmml.Attr
150150
booleanOperator=attribute.predicate.booleanOperator.name,
151151
SimplePredicate=[
152152
nyoka_pmml.SimplePredicate(field=sp.field, operator=sp.operator.name, value=sp.value)
153-
for sp in attribute.predicate.simplePredicates]))
153+
for sp in attribute.predicate.simplePredicates]),
154+
True_=None if (attribute.predicate is None or not isinstance(
155+
attribute.predicate, models.TruePredicate)) else nyoka_pmml.True_())
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import datetime
2+
import io
3+
from unittest import TestCase
4+
5+
try:
6+
import pypmml
7+
import pandas
8+
except ImportError:
9+
pypmml = None
10+
pandas = None
11+
12+
import aix360.algorithms.rule_induction.trxf.pmml_export.models as models
13+
import aix360.algorithms.rule_induction.trxf.pmml_export.serializer as serializer
14+
15+
16+
class TestNyokaSerializer(TestCase):
17+
nyokaSerializer = serializer.NyokaSerializer(datetime.datetime.now())
18+
19+
def setUp(self):
20+
if pypmml is None or pandas is None:
21+
self.skipTest('Install pypmml and pandas for integration tests')
22+
23+
def test_serialize_then_predict(self):
24+
# arrange
25+
score_card = models.Scorecard(
26+
models.DataDictionary(
27+
[
28+
models.DataField(
29+
name='department', dataType=models.DataType.string, optype=models.OpType.categorical),
30+
models.DataField(
31+
name='age', dataType=models.DataType.integer, optype=models.OpType.continuous),
32+
models.DataField(
33+
name='income', dataType=models.DataType.double, optype=models.OpType.continuous),
34+
models.DataField(
35+
name='overallScore', dataType=models.DataType.double, optype=models.OpType.continuous)
36+
]
37+
),
38+
miningSchema=models.MiningSchema(
39+
[
40+
models.MiningField(name='department', usageType=models.MiningFieldUsageType.active),
41+
models.MiningField(name='age', usageType=models.MiningFieldUsageType.active),
42+
models.MiningField(name='income', usageType=models.MiningFieldUsageType.active),
43+
models.MiningField(name='overallScore', usageType=models.MiningFieldUsageType.target),
44+
]
45+
),
46+
output=models.Output([
47+
models.OutputField(
48+
name='Final Score',
49+
feature='predictedValue',
50+
dataType=models.DataType.double,
51+
optype=models.OpType.continuous)
52+
]),
53+
characteristics=models.Characteristics(
54+
[
55+
models.Characteristic(name='departmentScore', attributes=[
56+
models.Attribute(
57+
score='-9',
58+
predicate=models.SimplePredicate(
59+
field='department', operator=models.Operator.isMissing)),
60+
models.Attribute(
61+
score='19',
62+
predicate=models.SimplePredicate(
63+
field='department', operator=models.Operator.equal, value='marketing')),
64+
models.Attribute(
65+
score='3',
66+
predicate=models.SimplePredicate(
67+
field='department', operator=models.Operator.equal, value='engineering')),
68+
models.Attribute(
69+
score='6',
70+
predicate=models.SimplePredicate(
71+
field='department', operator=models.Operator.equal, value='business')),
72+
models.Attribute(
73+
score='0',
74+
predicate=models.TruePredicate()),
75+
76+
]),
77+
models.Characteristic(
78+
name='ageScore', attributes=[
79+
models.Attribute(
80+
score='-1',
81+
predicate=models.SimplePredicate(
82+
field='age', operator=models.Operator.isMissing)),
83+
models.Attribute(
84+
score='-3',
85+
predicate=models.SimplePredicate(
86+
field='age', operator=models.Operator.lessOrEqual, value='18')),
87+
models.Attribute(
88+
score='0',
89+
predicate=models.CompoundPredicate(
90+
booleanOperator=models.BooleanOperator.and_,
91+
simplePredicates=[
92+
models.SimplePredicate(
93+
field='age', operator=models.Operator.greaterThan, value='18'),
94+
models.SimplePredicate(
95+
field='age', operator=models.Operator.lessOrEqual, value='29')])),
96+
models.Attribute(
97+
score='12',
98+
predicate=models.CompoundPredicate(
99+
booleanOperator=models.BooleanOperator.and_,
100+
simplePredicates=[
101+
models.SimplePredicate(
102+
field='age', operator=models.Operator.greaterThan, value='29'),
103+
models.SimplePredicate(
104+
field='age', operator=models.Operator.lessOrEqual, value='39')])),
105+
models.Attribute(
106+
score='18',
107+
predicate=models.SimplePredicate(
108+
field='age', operator=models.Operator.greaterThan, value='39'))]),
109+
models.Characteristic(name='incomeScore', attributes=[
110+
models.Attribute(
111+
score='3',
112+
predicate=models.SimplePredicate(
113+
field='income', operator=models.Operator.isMissing)),
114+
models.Attribute(
115+
predicate=models.SimplePredicate(
116+
field='income', operator=models.Operator.equal, value='1000'),
117+
score=models.ComplexPartialScore(feature_name='income', multiplier='0.03', constant='11')),
118+
models.Attribute(
119+
score='5',
120+
predicate=models.CompoundPredicate(
121+
booleanOperator=models.BooleanOperator.and_,
122+
simplePredicates=[
123+
models.SimplePredicate(
124+
field='income', operator=models.Operator.greaterThan, value='1000'),
125+
models.SimplePredicate(
126+
field='income', operator=models.Operator.lessOrEqual, value='2500')])),
127+
models.Attribute(
128+
predicate=models.SimplePredicate(
129+
field='income', operator=models.Operator.greaterThan, value='1500'),
130+
score=models.ComplexPartialScore(
131+
feature_name='income', multiplier='0.01', constant='18'))])]))
132+
133+
# when
134+
serialized = self.nyokaSerializer.serialize(score_card)
135+
pmml_model = pypmml.Model.load(io.StringIO(serialized))
136+
137+
# assert
138+
self.assertIsNotNone(pmml_model)
139+
self.assertEqual(len(pmml_model.dataDictionary.fields), 4)

0 commit comments

Comments
 (0)