diff --git a/library/src/main/java/org/apache/apex/malhar/lib/window/accumulation/PojoInnerJoin.java b/library/src/main/java/org/apache/apex/malhar/lib/window/accumulation/PojoInnerJoin.java index 1872d19db3..a5b1117a68 100644 --- a/library/src/main/java/org/apache/apex/malhar/lib/window/accumulation/PojoInnerJoin.java +++ b/library/src/main/java/org/apache/apex/malhar/lib/window/accumulation/PojoInnerJoin.java @@ -27,9 +27,13 @@ import java.util.Map; import org.apache.apex.malhar.lib.window.MergeAccumulation; +import org.apache.commons.lang3.ClassUtils; import com.google.common.base.Throwables; +import com.datatorrent.lib.util.KeyValPair; +import com.datatorrent.lib.util.PojoUtils; + /** * Inner join Accumulation for Pojo Streams. * @@ -40,6 +44,9 @@ public class PojoInnerJoin { protected final String[] keys; protected final Class outClass; + private transient List> gettersStream1; + private transient List> gettersStream2; + private transient List> setters; public PojoInnerJoin() { @@ -57,9 +64,35 @@ public PojoInnerJoin(int num, Class outClass, String... keys) this.outClass = outClass; } + private void createSetters() + { + Field[] fields = outClass.getDeclaredFields(); + setters = new ArrayList<>(); + for (Field field : fields) { + Class outputField = ClassUtils.primitiveToWrapper(field.getType()); + String fieldName = field.getName(); + setters.add(new KeyValPair<>(fieldName,PojoUtils.createSetter(outClass,fieldName,outputField))); + } + } + + private List> createGetters(Class input) + { + Field[] fields = input.getDeclaredFields(); + List> getters = new ArrayList<>(); + for (Field field : fields) { + Class inputField = ClassUtils.primitiveToWrapper(field.getType()); + String fieldName = field.getName(); + getters.add(new KeyValPair<>(fieldName,PojoUtils.createGetter(input, fieldName, inputField))); + } + return getters; + } + @Override public List>> accumulate(List>> accumulatedValue, InputT1 input) { + if (gettersStream1 == null) { + gettersStream1 = createGetters(input.getClass()); + } try { return accumulateWithIndex(0, accumulatedValue, input); } catch (NoSuchFieldException e) { @@ -70,6 +103,9 @@ public List>> accumulate(List> @Override public List>> accumulate2(List>> accumulatedValue, InputT2 input) { + if (gettersStream2 == null) { + gettersStream2 = createGetters(input.getClass()); + } try { return accumulateWithIndex(1, accumulatedValue, input); } catch (NoSuchFieldException e) { @@ -96,27 +132,23 @@ private List>> accumulateWithIndex(int index, List> curList = accu.get(index); - Map map = pojoToMap(input); + Map map = pojoToMap(input,index + 1); curList.add(map); accu.set(index, curList); return accu; } - private Map pojoToMap(Object input) + private Map pojoToMap(Object input, int streamIndex) { Map map = new HashMap<>(); + List> gettersStream = streamIndex == 1 ? gettersStream1 : gettersStream2; - Field[] fields = input.getClass().getDeclaredFields(); - - for (Field field : fields) { - String[] words = field.getName().split("\\."); - String fieldName = words[words.length - 1]; - field.setAccessible(true); + for (KeyValPair getter : gettersStream) { try { - Object value = field.get(input); - map.put(fieldName, value); - } catch (IllegalAccessException e) { + Object value = getter.getValue().get(input); + map.put(getter.getKey(), value); + } catch (Exception e) { throw Throwables.propagate(e); } } @@ -142,6 +174,10 @@ public List getOutput(List>> accumulatedValue) // TODO: May need to revisit (use state manager). result = getAllCombo(0, accumulatedValue, result, null); + if (setters == null) { + createSetters(); + } + List out = new ArrayList<>(); for (Map resultMap : result) { Object o; @@ -150,16 +186,8 @@ public List getOutput(List>> accumulatedValue) } catch (Throwable e) { throw Throwables.propagate(e); } - - for (Map.Entry entry : resultMap.entrySet()) { - Field f; - try { - f = outClass.getDeclaredField(entry.getKey()); - f.setAccessible(true); - f.set(o, entry.getValue()); - } catch (NoSuchFieldException | IllegalAccessException e) { - throw Throwables.propagate(e); - } + for (KeyValPair setter : setters) { + setter.getValue().set(o,resultMap.get(setter.getKey())); } out.add(o); } diff --git a/library/src/test/java/org/apache/apex/malhar/lib/window/accumulation/PojoInnerJoinTest.java b/library/src/test/java/org/apache/apex/malhar/lib/window/accumulation/PojoInnerJoinTest.java index 47ce815312..47c7307d7c 100644 --- a/library/src/test/java/org/apache/apex/malhar/lib/window/accumulation/PojoInnerJoinTest.java +++ b/library/src/test/java/org/apache/apex/malhar/lib/window/accumulation/PojoInnerJoinTest.java @@ -47,22 +47,22 @@ public TestPojo1(int id, String name) this.uName = name; } - public int getuId() + public int getUId() { return uId; } - public void setuId(int uId) + public void setUId(int uId) { this.uId = uId; } - public String getuName() + public String getUName() { return uName; } - public void setuName(String uName) + public void setUName(String uName) { this.uName = uName; } @@ -84,12 +84,12 @@ public TestPojo2(int id, String dep) this.dep = dep; } - public int getuId() + public int getUId() { return uId; } - public void setuId(int uId) + public void setUId(int uId) { this.uId = uId; } @@ -111,22 +111,22 @@ public static class TestOutClass private String uName; private String dep; - public int getuId() + public int getUId() { return uId; } - public void setuId(int uId) + public void setUId(int uId) { this.uId = uId; } - public String getuName() + public String getUName() { return uName; } - public void setuName(String uName) + public void setUName(String uName) { this.uName = uName; } @@ -168,8 +168,8 @@ public void PojoInnerJoinTest() Object o = pij.getOutput(accu).get(0); Assert.assertTrue(o instanceof TestOutClass); TestOutClass testOutClass = (TestOutClass)o; - Assert.assertEquals(1, testOutClass.getuId()); - Assert.assertEquals("Josh", testOutClass.getuName()); + Assert.assertEquals(1, testOutClass.getUId()); + Assert.assertEquals("Josh", testOutClass.getUName()); Assert.assertEquals("CS", testOutClass.getDep()); } } diff --git a/library/src/test/java/org/apache/apex/malhar/lib/window/impl/PojoInnerJoinTestApplication.java b/library/src/test/java/org/apache/apex/malhar/lib/window/impl/PojoInnerJoinTestApplication.java index 809023b166..3d2e82a943 100644 --- a/library/src/test/java/org/apache/apex/malhar/lib/window/impl/PojoInnerJoinTestApplication.java +++ b/library/src/test/java/org/apache/apex/malhar/lib/window/impl/PojoInnerJoinTestApplication.java @@ -292,6 +292,76 @@ public void setTimestamp(long timestamp) } } + public static class OutputEvent + { + public int customerId; + public int productId; + public int productCategory; + public long timestamp; + public double amount; + public long timestamps; + + public int getCustomerId() + { + return customerId; + } + + public void setCustomerId(int customerId) + { + this.customerId = customerId; + } + + public int getProductId() + { + return productId; + } + + public void setProductId(int productId) + { + this.productId = productId; + } + + public int getProductCategory() + { + return productCategory; + } + + public void setProductCategory(int productCategory) + { + this.productCategory = productCategory; + } + + public long getTimestamp() + { + return timestamp; + } + + public void setTimestamp(long timestamp) + { + this.timestamp = timestamp; + } + + public double getAmount() + { + return amount; + } + + public void setAmount(double amount) + { + this.amount = amount; + } + + public long getTimestamps() + { + return timestamps; + } + + public void setTimestamps(long timestamp) + { + this.timestamps = timestamp; + } + } + public int getMaxProductId() { return maxProductId; @@ -375,7 +445,7 @@ public void populateDAG(DAG dag, Configuration conf) productGenerator.setSalesEvent(false); WindowedMergeOperatorImpl>, List>> op = dag.addOperator("Merge", new WindowedMergeOperatorImpl>, List>>()); - op.setAccumulation(new PojoInnerJoin(2, Object.class, "productId","productId")); + op.setAccumulation(new PojoInnerJoin(2, POJOGenerator.OutputEvent.class, "productId","productId")); op.setDataStorage(new InMemoryWindowedStorage>>()); WindowedStorage.WindowedPlainStorage windowStateMap = new InMemoryWindowedStorage<>();